OpenMPOpt.cpp
1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
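// For illustration only (example user code, not taken from this file): input
// along the lines of
//
//   #pragma omp parallel
//   {
//     int Tid0 = omp_get_thread_num();
//     int Tid1 = omp_get_thread_num(); // duplicate of Tid0, can be reused
//   }
//
// is what the runtime-call deduplication targets; the remaining
// transformations apply analogously to offloaded device kernels.
//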
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/ADT/StringRef.h"
34 #include "llvm/IR/Assumptions.h"
35 #include "llvm/IR/Constants.h"
36 #include "llvm/IR/DiagnosticInfo.h"
37 #include "llvm/IR/GlobalValue.h"
38 #include "llvm/IR/GlobalVariable.h"
39 #include "llvm/IR/Instruction.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/IntrinsicInst.h"
42 #include "llvm/IR/IntrinsicsAMDGPU.h"
43 #include "llvm/IR/IntrinsicsNVPTX.h"
44 #include "llvm/IR/LLVMContext.h"
45 #include "llvm/InitializePasses.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Transforms/IPO.h"
52 
53 #include <algorithm>
54 
55 using namespace llvm;
56 using namespace omp;
57 
58 #define DEBUG_TYPE "openmp-opt"
59 
60 static cl::opt<bool> DisableOpenMPOptimizations(
61  "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
62  cl::Hidden, cl::init(false));
63 
64 static cl::opt<bool> EnableParallelRegionMerging(
65  "openmp-opt-enable-merging",
66  cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
67  cl::init(false));
68 
69 static cl::opt<bool>
70  DisableInternalization("openmp-opt-disable-internalization",
71  cl::desc("Disable function internalization."),
72  cl::Hidden, cl::init(false));
73 
74 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
75  cl::Hidden);
76 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
77  cl::init(false), cl::Hidden);
78 
79 static cl::opt<bool> HideMemoryTransferLatency(
80  "openmp-hide-memory-transfer-latency",
81  cl::desc("[WIP] Tries to hide the latency of host to device memory"
82  " transfers"),
83  cl::Hidden, cl::init(false));
84 
85 static cl::opt<bool> DisableOpenMPOptDeglobalization(
86  "openmp-opt-disable-deglobalization",
87  cl::desc("Disable OpenMP optimizations involving deglobalization."),
88  cl::Hidden, cl::init(false));
89 
90 static cl::opt<bool> DisableOpenMPOptSPMDization(
91  "openmp-opt-disable-spmdization",
92  cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
93  cl::Hidden, cl::init(false));
94 
95 static cl::opt<bool> DisableOpenMPOptFolding(
96  "openmp-opt-disable-folding",
97  cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
98  cl::init(false));
99 
100 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
101  "openmp-opt-disable-state-machine-rewrite",
102  cl::desc("Disable OpenMP optimizations that replace the state machine."),
103  cl::Hidden, cl::init(false));
104 
105 static cl::opt<bool> DisableOpenMPOptBarrierElimination(
106  "openmp-opt-disable-barrier-elimination",
107  cl::desc("Disable OpenMP optimizations that eliminate barriers."),
108  cl::Hidden, cl::init(false));
109 
110 static cl::opt<bool> PrintModuleAfterOptimizations(
111  "openmp-opt-print-module-after",
112  cl::desc("Print the current module after OpenMP optimizations."),
113  cl::Hidden, cl::init(false));
114 
115 static cl::opt<bool> PrintModuleBeforeOptimizations(
116  "openmp-opt-print-module-before",
117  cl::desc("Print the current module before OpenMP optimizations."),
118  cl::Hidden, cl::init(false));
119 
120 static cl::opt<bool> AlwaysInlineDeviceFunctions(
121  "openmp-opt-inline-device",
122  cl::desc("Inline all applicable functions on the device."), cl::Hidden,
123  cl::init(false));
124 
125 static cl::opt<bool>
126  EnableVerboseRemarks("openmp-opt-verbose-remarks",
127  cl::desc("Enables more verbose remarks."), cl::Hidden,
128  cl::init(false));
129 
130 static cl::opt<unsigned>
131  SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
132  cl::desc("Maximal number of attributor iterations."),
133  cl::init(256));
134 
135 static cl::opt<unsigned>
136  SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
137  cl::desc("Maximum amount of shared memory to use."),
138  cl::init(__UINT32_MAX__));
139 
140 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
141  "Number of OpenMP runtime calls deduplicated");
142 STATISTIC(NumOpenMPParallelRegionsDeleted,
143  "Number of OpenMP parallel regions deleted");
144 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
145  "Number of OpenMP runtime functions identified");
146 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
147  "Number of OpenMP runtime function uses identified");
148 STATISTIC(NumOpenMPTargetRegionKernels,
149  "Number of OpenMP target region entry points (=kernels) identified");
150 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
151  "Number of OpenMP target region entry points (=kernels) executed in "
152  "SPMD-mode instead of generic-mode");
153 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
154  "Number of OpenMP target region entry points (=kernels) executed in "
155  "generic-mode without a state machine");
156 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
157  "Number of OpenMP target region entry points (=kernels) executed in "
158  "generic-mode with customized state machines with fallback");
159 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
160  "Number of OpenMP target region entry points (=kernels) executed in "
161  "generic-mode with customized state machines without fallback");
162 STATISTIC(
163  NumOpenMPParallelRegionsReplacedInGPUStateMachine,
164  "Number of OpenMP parallel regions replaced with ID in GPU state machines");
165 STATISTIC(NumOpenMPParallelRegionsMerged,
166  "Number of OpenMP parallel regions merged");
167 STATISTIC(NumBytesMovedToSharedMemory,
168  "Amount of memory pushed to shared memory");
169 STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
170 
171 #if !defined(NDEBUG)
172 static constexpr auto TAG = "[" DEBUG_TYPE "]";
173 #endif
174 
175 namespace {
176 
177 struct AAHeapToShared;
178 
179 struct AAICVTracker;
180 
181 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
182 /// Attributor runs.
183 struct OMPInformationCache : public InformationCache {
184  OMPInformationCache(Module &M, AnalysisGetter &AG,
185  BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
186  KernelSet &Kernels)
187  : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
188  Kernels(Kernels) {
189 
190  OMPBuilder.initialize();
191  initializeRuntimeFunctions(M);
192  initializeInternalControlVars();
193  }
194 
195  /// Generic information that describes an internal control variable.
196  struct InternalControlVarInfo {
197  /// The kind, as described by InternalControlVar enum.
198  InternalControlVar Kind;
199 
200  /// The name of the ICV.
201  StringRef Name;
202 
203  /// Environment variable associated with this ICV.
204  StringRef EnvVarName;
205 
206  /// Initial value kind.
207  ICVInitValue InitKind;
208 
209  /// Initial value.
210  ConstantInt *InitValue;
211 
212  /// Setter RTL function associated with this ICV.
213  RuntimeFunction Setter;
214 
215  /// Getter RTL function associated with this ICV.
216  RuntimeFunction Getter;
217 
218  /// RTL Function corresponding to the override clause of this ICV
219  RuntimeFunction Clause;
220  };
221 
222  /// Generic information that describes a runtime function
223  struct RuntimeFunctionInfo {
224 
225  /// The kind, as described by the RuntimeFunction enum.
226  RuntimeFunction Kind;
227 
228  /// The name of the function.
229  StringRef Name;
230 
231  /// Flag to indicate a variadic function.
232  bool IsVarArg;
233 
234  /// The return type of the function.
235  Type *ReturnType;
236 
237  /// The argument types of the function.
238  SmallVector<Type *, 8> ArgumentTypes;
239 
240  /// The declaration if available.
241  Function *Declaration = nullptr;
242 
243  /// Uses of this runtime function per function containing the use.
244  using UseVector = SmallVector<Use *, 16>;
245 
246  /// Clear UsesMap for runtime function.
247  void clearUsesMap() { UsesMap.clear(); }
248 
249  /// Boolean conversion that is true if the runtime function was found.
250  operator bool() const { return Declaration; }
251 
252  /// Return the vector of uses in function \p F.
253  UseVector &getOrCreateUseVector(Function *F) {
254  std::shared_ptr<UseVector> &UV = UsesMap[F];
255  if (!UV)
256  UV = std::make_shared<UseVector>();
257  return *UV;
258  }
259 
260  /// Return the vector of uses in function \p F or `nullptr` if there are
261  /// none.
262  const UseVector *getUseVector(Function &F) const {
263  auto I = UsesMap.find(&F);
264  if (I != UsesMap.end())
265  return I->second.get();
266  return nullptr;
267  }
268 
269  /// Return how many functions contain uses of this runtime function.
270  size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
271 
272  /// Return the number of arguments (or the minimal number for variadic
273  /// functions).
274  size_t getNumArgs() const { return ArgumentTypes.size(); }
275 
276  /// Run the callback \p CB on each use and forget the use if the result is
277  /// true. The callback will be fed the function in which the use was
278  /// encountered as second argument.
279  void foreachUse(SmallVectorImpl<Function *> &SCC,
280  function_ref<bool(Use &, Function &)> CB) {
281  for (Function *F : SCC)
282  foreachUse(CB, F);
283  }
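 
  // Illustrative sketch (hypothetical callback, mirroring how later callers
  // such as OpenMPOpt::deleteParallelRegions use this helper): erase every
  // regular call to the runtime function and forget the use by returning true.
  //
  //   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
  //     CallInst *CI = getCallIfRegularCall(U, &RFI);
  //     if (!CI)
  //       return false;
  //     CI->eraseFromParent();
  //     return true; // drop this use from the map
  //   });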
284 
285  /// Run the callback \p CB on each use within the function \p F and forget
286  /// the use if the result is true.
287  void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
288  SmallVector<unsigned, 8> ToBeDeleted;
289  ToBeDeleted.clear();
290 
291  unsigned Idx = 0;
292  UseVector &UV = getOrCreateUseVector(F);
293 
294  for (Use *U : UV) {
295  if (CB(*U, *F))
296  ToBeDeleted.push_back(Idx);
297  ++Idx;
298  }
299 
300  // Remove the to-be-deleted indices in reverse order as prior
301  // modifications will not modify the smaller indices.
302  while (!ToBeDeleted.empty()) {
303  unsigned Idx = ToBeDeleted.pop_back_val();
304  UV[Idx] = UV.back();
305  UV.pop_back();
306  }
307  }
308 
309  private:
310  /// Map from functions to all uses of this runtime function contained in
311  /// them.
312  DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
313 
314  public:
315  /// Iterators for the uses of this runtime function.
316  decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
317  decltype(UsesMap)::iterator end() { return UsesMap.end(); }
318  };
319 
320  /// An OpenMP-IR-Builder instance
321  OpenMPIRBuilder OMPBuilder;
322 
323  /// Map from runtime function kind to the runtime function description.
324  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
325  RuntimeFunction::OMPRTL___last>
326  RFIs;
327 
328  /// Map from function declarations/definitions to their runtime enum type.
329  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
330 
331  /// Map from ICV kind to the ICV description.
332  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
333  InternalControlVar::ICV___last>
334  ICVs;
335 
336  /// Helper to initialize all internal control variable information for those
337  /// defined in OMPKinds.def.
338  void initializeInternalControlVars() {
339 #define ICV_RT_SET(_Name, RTL) \
340  { \
341  auto &ICV = ICVs[_Name]; \
342  ICV.Setter = RTL; \
343  }
344 #define ICV_RT_GET(Name, RTL) \
345  { \
346  auto &ICV = ICVs[Name]; \
347  ICV.Getter = RTL; \
348  }
349 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
350  { \
351  auto &ICV = ICVs[Enum]; \
352  ICV.Name = _Name; \
353  ICV.Kind = Enum; \
354  ICV.InitKind = Init; \
355  ICV.EnvVarName = _EnvVarName; \
356  switch (ICV.InitKind) { \
357  case ICV_IMPLEMENTATION_DEFINED: \
358  ICV.InitValue = nullptr; \
359  break; \
360  case ICV_ZERO: \
361  ICV.InitValue = ConstantInt::get( \
362  Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
363  break; \
364  case ICV_FALSE: \
365  ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
366  break; \
367  case ICV_LAST: \
368  break; \
369  } \
370  }
371 #include "llvm/Frontend/OpenMP/OMPKinds.def"
372  }
373 
374  /// Returns true if the function declaration \p F matches the runtime
375  /// function types, that is, return type \p RTFRetType, and argument types
376  /// \p RTFArgTypes.
377  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
378  SmallVector<Type *, 8> &RTFArgTypes) {
379  // TODO: We should output information to the user (under debug output
380  // and via remarks).
381 
382  if (!F)
383  return false;
384  if (F->getReturnType() != RTFRetType)
385  return false;
386  if (F->arg_size() != RTFArgTypes.size())
387  return false;
388 
389  auto *RTFTyIt = RTFArgTypes.begin();
390  for (Argument &Arg : F->args()) {
391  if (Arg.getType() != *RTFTyIt)
392  return false;
393 
394  ++RTFTyIt;
395  }
396 
397  return true;
398  }
399 
400  // Helper to collect all uses of the declaration in the UsesMap.
401  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
402  unsigned NumUses = 0;
403  if (!RFI.Declaration)
404  return NumUses;
405  OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
406 
407  if (CollectStats) {
408  NumOpenMPRuntimeFunctionsIdentified += 1;
409  NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
410  }
411 
412  // TODO: We directly convert uses into proper calls and unknown uses.
413  for (Use &U : RFI.Declaration->uses()) {
414  if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
415  if (ModuleSlice.empty() || ModuleSlice.count(UserI->getFunction())) {
416  RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
417  ++NumUses;
418  }
419  } else {
420  RFI.getOrCreateUseVector(nullptr).push_back(&U);
421  ++NumUses;
422  }
423  }
424  return NumUses;
425  }
426 
427  // Helper function to recollect uses of a runtime function.
428  void recollectUsesForFunction(RuntimeFunction RTF) {
429  auto &RFI = RFIs[RTF];
430  RFI.clearUsesMap();
431  collectUses(RFI, /*CollectStats*/ false);
432  }
433 
434  // Helper function to recollect uses of all runtime functions.
435  void recollectUses() {
436  for (int Idx = 0; Idx < RFIs.size(); ++Idx)
437  recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
438  }
439 
440  // Helper function to inherit the calling convention of the function callee.
441  void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
442  if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
443  CI->setCallingConv(Fn->getCallingConv());
444  }
445 
446  /// Helper to initialize all runtime function information for those defined
447  /// in OpenMPKinds.def.
448  void initializeRuntimeFunctions(Module &M) {
449 
450  // Helper macros for handling __VA_ARGS__ in OMP_RTL
451 #define OMP_TYPE(VarName, ...) \
452  Type *VarName = OMPBuilder.VarName; \
453  (void)VarName;
454 
455 #define OMP_ARRAY_TYPE(VarName, ...) \
456  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
457  (void)VarName##Ty; \
458  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
459  (void)VarName##PtrTy;
460 
461 #define OMP_FUNCTION_TYPE(VarName, ...) \
462  FunctionType *VarName = OMPBuilder.VarName; \
463  (void)VarName; \
464  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
465  (void)VarName##Ptr;
466 
467 #define OMP_STRUCT_TYPE(VarName, ...) \
468  StructType *VarName = OMPBuilder.VarName; \
469  (void)VarName; \
470  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
471  (void)VarName##Ptr;
472 
473 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
474  { \
475  SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
476  Function *F = M.getFunction(_Name); \
477  RTLFunctions.insert(F); \
478  if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
479  RuntimeFunctionIDMap[F] = _Enum; \
480  auto &RFI = RFIs[_Enum]; \
481  RFI.Kind = _Enum; \
482  RFI.Name = _Name; \
483  RFI.IsVarArg = _IsVarArg; \
484  RFI.ReturnType = OMPBuilder._ReturnType; \
485  RFI.ArgumentTypes = std::move(ArgsTypes); \
486  RFI.Declaration = F; \
487  unsigned NumUses = collectUses(RFI); \
488  (void)NumUses; \
489  LLVM_DEBUG({ \
490  dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
491  << " found\n"; \
492  if (RFI.Declaration) \
493  dbgs() << TAG << "-> got " << NumUses << " uses in " \
494  << RFI.getNumFunctionsWithUses() \
495  << " different functions.\n"; \
496  }); \
497  } \
498  }
499 #include "llvm/Frontend/OpenMP/OMPKinds.def"
500 
501  // Remove the `noinline` attribute from `__kmpc`, `_OMP::` and `omp_`
502  // functions, except if `optnone` is present.
503  if (isOpenMPDevice(M)) {
504  for (Function &F : M) {
505  for (StringRef Prefix : {"__kmpc", "_ZN4_OMP", "omp_"})
506  if (F.hasFnAttribute(Attribute::NoInline) &&
507  F.getName().startswith(Prefix) &&
508  !F.hasFnAttribute(Attribute::OptimizeNone))
509  F.removeFnAttr(Attribute::NoInline);
510  }
511  }
512 
513  // TODO: We should attach the attributes defined in OMPKinds.def.
514  }
515 
516  /// Collection of known kernels (\see Kernel) in the module.
517  KernelSet &Kernels;
518 
519  /// Collection of known OpenMP runtime functions.
520  DenseSet<const Function *> RTLFunctions;
521 };
522 
523 template <typename Ty, bool InsertInvalidates = true>
524 struct BooleanStateWithSetVector : public BooleanState {
525  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
526  bool insert(const Ty &Elem) {
527  if (InsertInvalidates)
528  BooleanState::indicatePessimisticFixpoint();
529  return Set.insert(Elem);
530  }
531 
532  const Ty &operator[](int Idx) const { return Set[Idx]; }
533  bool operator==(const BooleanStateWithSetVector &RHS) const {
534  return BooleanState::operator==(RHS) && Set == RHS.Set;
535  }
536  bool operator!=(const BooleanStateWithSetVector &RHS) const {
537  return !(*this == RHS);
538  }
539 
540  bool empty() const { return Set.empty(); }
541  size_t size() const { return Set.size(); }
542 
543  /// "Clamp" this state with \p RHS.
544  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
545  BooleanState::operator^=(RHS);
546  Set.insert(RHS.Set.begin(), RHS.Set.end());
547  return *this;
548  }
549 
550 private:
551  /// A set to keep track of elements.
552  SetVector<Ty> Set;
553 
554 public:
555  typename decltype(Set)::iterator begin() { return Set.begin(); }
556  typename decltype(Set)::iterator end() { return Set.end(); }
557  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
558  typename decltype(Set)::const_iterator end() const { return Set.end(); }
559 };
560 
561 template <typename Ty, bool InsertInvalidates = true>
562 using BooleanStateWithPtrSetVector =
563  BooleanStateWithSetVector<Ty *, InsertInvalidates>;
564 
565 struct KernelInfoState : AbstractState {
566  /// Flag to track if we reached a fixpoint.
567  bool IsAtFixpoint = false;
568 
569  /// The parallel regions (identified by the outlined parallel functions) that
570  /// can be reached from the associated function.
571  BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
572  ReachedKnownParallelRegions;
573 
574  /// State to track what parallel region we might reach.
575  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
576 
577  /// State to track if we are in SPMD-mode, assumed or known, and why we decided
578  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
579  /// false.
580  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
581 
582  /// The __kmpc_target_init call in this kernel, if any. If we find more than
583  /// one we abort as the kernel is malformed.
584  CallBase *KernelInitCB = nullptr;
585 
586  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
587  /// one we abort as the kernel is malformed.
588  CallBase *KernelDeinitCB = nullptr;
589 
590  /// Flag to indicate if the associated function is a kernel entry.
591  bool IsKernelEntry = false;
592 
593  /// State to track what kernel entries can reach the associated function.
594  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
595 
596  /// State to indicate if we can track parallel level of the associated
597  /// function. We will give up tracking if we encounter unknown caller or the
598  /// caller is __kmpc_parallel_51.
599  BooleanStateWithSetVector<uint8_t> ParallelLevels;
600 
601  /// Abstract State interface
602  ///{
603 
604  KernelInfoState() = default;
605  KernelInfoState(bool BestState) {
606  if (!BestState)
607  indicatePessimisticFixpoint();
608  }
609 
610  /// See AbstractState::isValidState(...)
611  bool isValidState() const override { return true; }
612 
613  /// See AbstractState::isAtFixpoint(...)
614  bool isAtFixpoint() const override { return IsAtFixpoint; }
615 
616  /// See AbstractState::indicatePessimisticFixpoint(...)
617  ChangeStatus indicatePessimisticFixpoint() override {
618  IsAtFixpoint = true;
619  ReachingKernelEntries.indicatePessimisticFixpoint();
620  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
621  ReachedKnownParallelRegions.indicatePessimisticFixpoint();
622  ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
623  return ChangeStatus::CHANGED;
624  }
625 
626  /// See AbstractState::indicateOptimisticFixpoint(...)
627  ChangeStatus indicateOptimisticFixpoint() override {
628  IsAtFixpoint = true;
629  ReachingKernelEntries.indicateOptimisticFixpoint();
630  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
631  ReachedKnownParallelRegions.indicateOptimisticFixpoint();
632  ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
633  return ChangeStatus::UNCHANGED;
634  }
635 
636  /// Return the assumed state
637  KernelInfoState &getAssumed() { return *this; }
638  const KernelInfoState &getAssumed() const { return *this; }
639 
640  bool operator==(const KernelInfoState &RHS) const {
641  if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
642  return false;
643  if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
644  return false;
645  if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
646  return false;
647  if (ReachingKernelEntries != RHS.ReachingKernelEntries)
648  return false;
649  return true;
650  }
651 
652  /// Returns true if this kernel contains any OpenMP parallel regions.
653  bool mayContainParallelRegion() {
654  return !ReachedKnownParallelRegions.empty() ||
655  !ReachedUnknownParallelRegions.empty();
656  }
657 
658  /// Return empty set as the best state of potential values.
659  static KernelInfoState getBestState() { return KernelInfoState(true); }
660 
661  static KernelInfoState getBestState(KernelInfoState &KIS) {
662  return getBestState();
663  }
664 
665  /// Return full set as the worst state of potential values.
666  static KernelInfoState getWorstState() { return KernelInfoState(false); }
667 
668  /// "Clamp" this state with \p KIS.
669  KernelInfoState operator^=(const KernelInfoState &KIS) {
670  // Do not merge two different _init and _deinit call sites.
671  if (KIS.KernelInitCB) {
672  if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
673  llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
674  "assumptions.");
675  KernelInitCB = KIS.KernelInitCB;
676  }
677  if (KIS.KernelDeinitCB) {
678  if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
679  llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
680  "assumptions.");
681  KernelDeinitCB = KIS.KernelDeinitCB;
682  }
683  SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
684  ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
685  ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
686  return *this;
687  }
688 
689  KernelInfoState operator&=(const KernelInfoState &KIS) {
690  return (*this ^= KIS);
691  }
692 
693  ///}
694 };
695 
696 /// Used to map the values physically (in the IR) stored in an offload
697 /// array, to a vector in memory.
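/// For illustration (simplified, hypothetical IR, not from this file): given
///   %array = alloca [2 x i8*]
///   store i8* %a, i8** %gep0
///   store i8* %b, i8** %gep1
///   call void @use(...)              ; the "Before" instruction
/// a successful initialize(%array, Before) ends up with StoredValues holding
/// the underlying objects of %a and %b and LastAccesses holding the two stores.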
698 struct OffloadArray {
699  /// Physical array (in the IR).
700  AllocaInst *Array = nullptr;
701  /// Mapped values.
702  SmallVector<Value *, 8> StoredValues;
703  /// Last stores made in the offload array.
704  SmallVector<StoreInst *, 8> LastAccesses;
705 
706  OffloadArray() = default;
707 
708  /// Initializes the OffloadArray with the values stored in \p Array before
709  /// instruction \p Before is reached. Returns false if the initialization
710  /// fails.
711  /// This MUST be used immediately after the construction of the object.
712  bool initialize(AllocaInst &Array, Instruction &Before) {
713  if (!Array.getAllocatedType()->isArrayTy())
714  return false;
715 
716  if (!getValues(Array, Before))
717  return false;
718 
719  this->Array = &Array;
720  return true;
721  }
722 
723  static const unsigned DeviceIDArgNum = 1;
724  static const unsigned BasePtrsArgNum = 3;
725  static const unsigned PtrsArgNum = 4;
726  static const unsigned SizesArgNum = 5;
727 
728 private:
729  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
730  /// \p Array, leaving StoredValues with the values stored before the
731  /// instruction \p Before is reached.
732  bool getValues(AllocaInst &Array, Instruction &Before) {
733  // Initialize container.
734  const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
735  StoredValues.assign(NumValues, nullptr);
736  LastAccesses.assign(NumValues, nullptr);
737 
738  // TODO: This assumes the instruction \p Before is in the same
739  // BasicBlock as Array. Make it general, for any control flow graph.
740  BasicBlock *BB = Array.getParent();
741  if (BB != Before.getParent())
742  return false;
743 
744  const DataLayout &DL = Array.getModule()->getDataLayout();
745  const unsigned int PointerSize = DL.getPointerSize();
746 
747  for (Instruction &I : *BB) {
748  if (&I == &Before)
749  break;
750 
751  if (!isa<StoreInst>(&I))
752  continue;
753 
754  auto *S = cast<StoreInst>(&I);
755  int64_t Offset = -1;
756  auto *Dst =
757  GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
758  if (Dst == &Array) {
759  int64_t Idx = Offset / PointerSize;
760  StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
761  LastAccesses[Idx] = S;
762  }
763  }
764 
765  return isFilled();
766  }
767 
768  /// Returns true if all values in StoredValues and
769  /// LastAccesses are not nullptrs.
770  bool isFilled() {
771  const unsigned NumValues = StoredValues.size();
772  for (unsigned I = 0; I < NumValues; ++I) {
773  if (!StoredValues[I] || !LastAccesses[I])
774  return false;
775  }
776 
777  return true;
778  }
779 };
780 
781 struct OpenMPOpt {
782 
783  using OptimizationRemarkGetter =
784  function_ref<OptimizationRemarkEmitter &(Function *)>;
785 
786  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
787  OptimizationRemarkGetter OREGetter,
788  OMPInformationCache &OMPInfoCache, Attributor &A)
789  : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
790  OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
791 
792  /// Check if any remarks are enabled for openmp-opt
793  bool remarksEnabled() {
794  auto &Ctx = M.getContext();
795  return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
796  }
797 
798  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
799  bool run(bool IsModulePass) {
800  if (SCC.empty())
801  return false;
802 
803  bool Changed = false;
804 
805  LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
806  << " functions in a slice with "
807  << OMPInfoCache.ModuleSlice.size() << " functions\n");
808 
809  if (IsModulePass) {
810  Changed |= runAttributor(IsModulePass);
811 
812  // Recollect uses, in case Attributor deleted any.
813  OMPInfoCache.recollectUses();
814 
815  // TODO: This should be folded into buildCustomStateMachine.
816  Changed |= rewriteDeviceCodeStateMachine();
817 
818  if (remarksEnabled())
819  analysisGlobalization();
820 
821  Changed |= eliminateBarriers();
822  } else {
823  if (PrintICVValues)
824  printICVs();
825  if (PrintOpenMPKernels)
826  printKernels();
827 
828  Changed |= runAttributor(IsModulePass);
829 
830  // Recollect uses, in case Attributor deleted any.
831  OMPInfoCache.recollectUses();
832 
833  Changed |= deleteParallelRegions();
834 
836  Changed |= hideMemTransfersLatency();
837  Changed |= deduplicateRuntimeCalls();
838  if (EnableParallelRegionMerging) {
839  if (mergeParallelRegions()) {
840  deduplicateRuntimeCalls();
841  Changed = true;
842  }
843  }
844 
845  Changed |= eliminateBarriers();
846  }
847 
848  return Changed;
849  }
850 
851  /// Print initial ICV values for testing.
852  /// FIXME: This should be done from the Attributor once it is added.
853  void printICVs() const {
854  InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
855  ICV_proc_bind};
856 
857  for (Function *F : SCC) {
858  for (auto ICV : ICVs) {
859  auto ICVInfo = OMPInfoCache.ICVs[ICV];
860  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
861  return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
862  << " Value: "
863  << (ICVInfo.InitValue
864  ? toString(ICVInfo.InitValue->getValue(), 10, true)
865  : "IMPLEMENTATION_DEFINED");
866  };
867 
868  emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
869  }
870  }
871  }
872 
873  /// Print OpenMP GPU kernels for testing.
874  void printKernels() const {
875  for (Function *F : SCC) {
876  if (!OMPInfoCache.Kernels.count(F))
877  continue;
878 
879  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
880  return ORA << "OpenMP GPU kernel "
881  << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
882  };
883 
884  emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
885  }
886  }
887 
888  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
889  /// given it has to be the callee or a nullptr is returned.
890  static CallInst *getCallIfRegularCall(
891  Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
892  CallInst *CI = dyn_cast<CallInst>(U.getUser());
893  if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
894  (!RFI ||
895  (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
896  return CI;
897  return nullptr;
898  }
899 
900  /// Return the call if \p V is a regular call. If \p RFI is given it has to be
901  /// the callee or a nullptr is returned.
902  static CallInst *getCallIfRegularCall(
903  Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
904  CallInst *CI = dyn_cast<CallInst>(&V);
905  if (CI && !CI->hasOperandBundles() &&
906  (!RFI ||
907  (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
908  return CI;
909  return nullptr;
910  }
911 
912 private:
913  /// Merge parallel regions when it is safe.
914  bool mergeParallelRegions() {
915  const unsigned CallbackCalleeOperand = 2;
916  const unsigned CallbackFirstArgOperand = 3;
917  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
918 
919  // Check if there are any __kmpc_fork_call calls to merge.
920  OMPInformationCache::RuntimeFunctionInfo &RFI =
921  OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
922 
923  if (!RFI.Declaration)
924  return false;
925 
926  // Unmergable calls that prevent merging a parallel region.
927  OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
928  OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
929  OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
930  };
931 
932  bool Changed = false;
933  LoopInfo *LI = nullptr;
934  DominatorTree *DT = nullptr;
935 
936  SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
937 
938  BasicBlock *StartBB = nullptr, *EndBB = nullptr;
939  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
940  BasicBlock *CGStartBB = CodeGenIP.getBlock();
941  BasicBlock *CGEndBB =
942  SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
943  assert(StartBB != nullptr && "StartBB should not be null");
944  CGStartBB->getTerminator()->setSuccessor(0, StartBB);
945  assert(EndBB != nullptr && "EndBB should not be null");
946  EndBB->getTerminator()->setSuccessor(0, CGEndBB);
947  };
948 
949  auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
950  Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
951  ReplacementValue = &Inner;
952  return CodeGenIP;
953  };
954 
955  auto FiniCB = [&](InsertPointTy CodeGenIP) {};
956 
957  /// Create a sequential execution region within a merged parallel region,
958  /// encapsulated in a master construct with a barrier for synchronization.
959  auto CreateSequentialRegion = [&](Function *OuterFn,
960  BasicBlock *OuterPredBB,
961  Instruction *SeqStartI,
962  Instruction *SeqEndI) {
963  // Isolate the instructions of the sequential region to a separate
964  // block.
965  BasicBlock *ParentBB = SeqStartI->getParent();
966  BasicBlock *SeqEndBB =
967  SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
968  BasicBlock *SeqAfterBB =
969  SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
970  BasicBlock *SeqStartBB =
971  SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
972 
973  assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
974  "Expected a different CFG");
975  const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
976  ParentBB->getTerminator()->eraseFromParent();
977 
978  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
979  BasicBlock *CGStartBB = CodeGenIP.getBlock();
980  BasicBlock *CGEndBB =
981  SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
982  assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
983  CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
984  assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
985  SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
986  };
987  auto FiniCB = [&](InsertPointTy CodeGenIP) {};
988 
989  // Find outputs from the sequential region to outside users and
990  // broadcast their values to them.
991  for (Instruction &I : *SeqStartBB) {
992  SmallPtrSet<Instruction *, 4> OutsideUsers;
993  for (User *Usr : I.users()) {
994  Instruction &UsrI = *cast<Instruction>(Usr);
995  // Ignore outputs to LT intrinsics, code extraction for the merged
996  // parallel region will fix them.
997  if (UsrI.isLifetimeStartOrEnd())
998  continue;
999 
1000  if (UsrI.getParent() != SeqStartBB)
1001  OutsideUsers.insert(&UsrI);
1002  }
1003 
1004  if (OutsideUsers.empty())
1005  continue;
1006 
1007  // Emit an alloca in the outer region to store the broadcasted
1008  // value.
1009  const DataLayout &DL = M.getDataLayout();
1010  AllocaInst *AllocaI = new AllocaInst(
1011  I.getType(), DL.getAllocaAddrSpace(), nullptr,
1012  I.getName() + ".seq.output.alloc", &OuterFn->front().front());
1013 
1014  // Emit a store instruction in the sequential BB to update the
1015  // value.
1016  new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
1017 
1018  // Emit a load instruction and replace the use of the output value
1019  // with it.
1020  for (Instruction *UsrI : OutsideUsers) {
1021  LoadInst *LoadI = new LoadInst(
1022  I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
1023  UsrI->replaceUsesOfWith(&I, LoadI);
1024  }
1025  }
1026 
1027  OpenMPIRBuilder::LocationDescription Loc(
1028  InsertPointTy(ParentBB, ParentBB->end()), DL);
1029  InsertPointTy SeqAfterIP =
1030  OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1031 
1032  OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1033 
1034  BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1035 
1036  LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1037  << "\n");
1038  };
1039 
1040  // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1041  // contained in BB and only separated by instructions that can be
1042  // redundantly executed in parallel. The block BB is split before the first
1043  // call (in MergableCIs) and after the last so the entire region we merge
1044  // into a single parallel region is contained in a single basic block
1045  // without any other instructions. We use the OpenMPIRBuilder to outline
1046  // that block and call the resulting function via __kmpc_fork_call.
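  // For illustration only (user-level view of the merge, example not from the
  // original sources):
  //
  //   #pragma omp parallel        // fork-call 1
  //   { foo(); }
  //   X = 7;                      // in-between, safe to keep inside the region
  //   #pragma omp parallel        // fork-call 2
  //   { bar(X); }
  //
  // becomes a single merged parallel region in which "X = 7" is wrapped in a
  // sequential master-plus-barrier region (see CreateSequentialRegion above).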
1047  auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1048  BasicBlock *BB) {
1049  // TODO: Change the interface to allow single CIs expanded, e.g, to
1050  // include an outer loop.
1051  assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1052 
1053  auto Remark = [&](OptimizationRemark OR) {
1054  OR << "Parallel region merged with parallel region"
1055  << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1056  for (auto *CI : llvm::drop_begin(MergableCIs)) {
1057  OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1058  if (CI != MergableCIs.back())
1059  OR << ", ";
1060  }
1061  return OR << ".";
1062  };
1063 
1064  emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1065 
1066  Function *OriginalFn = BB->getParent();
1067  LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1068  << " parallel regions in " << OriginalFn->getName()
1069  << "\n");
1070 
1071  // Isolate the calls to merge in a separate block.
1072  EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1073  BasicBlock *AfterBB =
1074  SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1075  StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1076  "omp.par.merged");
1077 
1078  assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1079  const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1080  BB->getTerminator()->eraseFromParent();
1081 
1082  // Create sequential regions for sequential instructions that are
1083  // in-between mergable parallel regions.
1084  for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1085  It != End; ++It) {
1086  Instruction *ForkCI = *It;
1087  Instruction *NextForkCI = *(It + 1);
1088 
1089  // Continue if there are no in-between instructions.
1090  if (ForkCI->getNextNode() == NextForkCI)
1091  continue;
1092 
1093  CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1094  NextForkCI->getPrevNode());
1095  }
1096 
1097  OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1098  DL);
1099  IRBuilder<>::InsertPoint AllocaIP(
1100  &OriginalFn->getEntryBlock(),
1101  OriginalFn->getEntryBlock().getFirstInsertionPt());
1102  // Create the merged parallel region with default proc binding, to
1103  // avoid overriding binding settings, and without explicit cancellation.
1104  InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1105  Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1106  OMP_PROC_BIND_default, /* IsCancellable */ false);
1107  BranchInst::Create(AfterBB, AfterIP.getBlock());
1108 
1109  // Perform the actual outlining.
1110  OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1111 
1112  Function *OutlinedFn = MergableCIs.front()->getCaller();
1113 
1114  // Replace the __kmpc_fork_call calls with direct calls to the outlined
1115  // callbacks.
1116  SmallVector<Value *, 8> Args;
1117  for (auto *CI : MergableCIs) {
1118  Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1119  FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1120  Args.clear();
1121  Args.push_back(OutlinedFn->getArg(0));
1122  Args.push_back(OutlinedFn->getArg(1));
1123  for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1124  ++U)
1125  Args.push_back(CI->getArgOperand(U));
1126 
1127  CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1128  if (CI->getDebugLoc())
1129  NewCI->setDebugLoc(CI->getDebugLoc());
1130 
1131  // Forward parameter attributes from the callback to the callee.
1132  for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1133  ++U)
1134  for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1135  NewCI->addParamAttr(
1136  U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1137 
1138  // Emit an explicit barrier to replace the implicit fork-join barrier.
1139  if (CI != MergableCIs.back()) {
1140  // TODO: Remove barrier if the merged parallel region includes the
1141  // 'nowait' clause.
1142  OMPInfoCache.OMPBuilder.createBarrier(
1143  InsertPointTy(NewCI->getParent(),
1144  NewCI->getNextNode()->getIterator()),
1145  OMPD_parallel);
1146  }
1147 
1148  CI->eraseFromParent();
1149  }
1150 
1151  assert(OutlinedFn != OriginalFn && "Outlining failed");
1152  CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1153  CGUpdater.reanalyzeFunction(*OriginalFn);
1154 
1155  NumOpenMPParallelRegionsMerged += MergableCIs.size();
1156 
1157  return true;
1158  };
1159 
1160  // Helper function that identifies sequences of
1161  // __kmpc_fork_call uses in a basic block.
1162  auto DetectPRsCB = [&](Use &U, Function &F) {
1163  CallInst *CI = getCallIfRegularCall(U, &RFI);
1164  BB2PRMap[CI->getParent()].insert(CI);
1165 
1166  return false;
1167  };
1168 
1169  BB2PRMap.clear();
1170  RFI.foreachUse(SCC, DetectPRsCB);
1171  SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1172  // Find mergable parallel regions within a basic block that are
1173  // safe to merge, that is any in-between instructions can safely
1174  // execute in parallel after merging.
1175  // TODO: support merging across basic-blocks.
1176  for (auto &It : BB2PRMap) {
1177  auto &CIs = It.getSecond();
1178  if (CIs.size() < 2)
1179  continue;
1180 
1181  BasicBlock *BB = It.getFirst();
1182  SmallVector<CallInst *, 4> MergableCIs;
1183 
1184  /// Returns true if the instruction is mergable, false otherwise.
1185  /// A terminator instruction is unmergable by definition since merging
1186  /// works within a BB. Instructions before the mergable region are
1187  /// mergable if they are not calls to OpenMP runtime functions that may
1188  /// set different execution parameters for subsequent parallel regions.
1189  /// Instructions in-between parallel regions are mergable if they are not
1190  /// calls to any non-intrinsic function since that may call a non-mergable
1191  /// OpenMP runtime function.
1192  auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1193  // We do not merge across BBs, hence return false (unmergable) if the
1194  // instruction is a terminator.
1195  if (I.isTerminator())
1196  return false;
1197 
1198  if (!isa<CallInst>(&I))
1199  return true;
1200 
1201  CallInst *CI = cast<CallInst>(&I);
1202  if (IsBeforeMergableRegion) {
1203  Function *CalledFunction = CI->getCalledFunction();
1204  if (!CalledFunction)
1205  return false;
1206  // Return false (unmergable) if the call before the parallel
1207  // region calls an explicit affinity (proc_bind) or number of
1208  // threads (num_threads) compiler-generated function. Those settings
1209  // may be incompatible with following parallel regions.
1210  // TODO: ICV tracking to detect compatibility.
1211  for (const auto &RFI : UnmergableCallsInfo) {
1212  if (CalledFunction == RFI.Declaration)
1213  return false;
1214  }
1215  } else {
1216  // Return false (unmergable) if there is a call instruction
1217  // in-between parallel regions when it is not an intrinsic. It
1218  // may call an unmergable OpenMP runtime function in its callpath.
1219  // TODO: Keep track of possible OpenMP calls in the callpath.
1220  if (!isa<IntrinsicInst>(CI))
1221  return false;
1222  }
1223 
1224  return true;
1225  };
1226  // Find maximal number of parallel region CIs that are safe to merge.
1227  for (auto It = BB->begin(), End = BB->end(); It != End;) {
1228  Instruction &I = *It;
1229  ++It;
1230 
1231  if (CIs.count(&I)) {
1232  MergableCIs.push_back(cast<CallInst>(&I));
1233  continue;
1234  }
1235 
1236  // Continue expanding if the instruction is mergable.
1237  if (IsMergable(I, MergableCIs.empty()))
1238  continue;
1239 
1240  // Forward the instruction iterator to skip the next parallel region
1241  // since there is an unmergable instruction which can affect it.
1242  for (; It != End; ++It) {
1243  Instruction &SkipI = *It;
1244  if (CIs.count(&SkipI)) {
1245  LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1246  << " due to " << I << "\n");
1247  ++It;
1248  break;
1249  }
1250  }
1251 
1252  // Store mergable regions found.
1253  if (MergableCIs.size() > 1) {
1254  MergableCIsVector.push_back(MergableCIs);
1255  LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1256  << " parallel regions in block " << BB->getName()
1257  << " of function " << BB->getParent()->getName()
1258  << "\n";);
1259  }
1260 
1261  MergableCIs.clear();
1262  }
1263 
1264  if (!MergableCIsVector.empty()) {
1265  Changed = true;
1266 
1267  for (auto &MergableCIs : MergableCIsVector)
1268  Merge(MergableCIs, BB);
1269  MergableCIsVector.clear();
1270  }
1271  }
1272 
1273  if (Changed) {
1274  /// Re-collect use for fork calls, emitted barrier calls, and
1275  /// any emitted master/end_master calls.
1276  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1277  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1278  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1279  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1280  }
1281 
1282  return Changed;
1283  }
1284 
1285  /// Try to delete parallel regions if possible.
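  /// For illustration (example not from this file): a region such as
  ///   #pragma omp parallel
  ///   { /* only read-only, guaranteed-to-return work */ }
  /// whose outlined callback only reads memory and carries the willreturn
  /// attribute can be erased together with its __kmpc_fork_call.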
1286  bool deleteParallelRegions() {
1287  const unsigned CallbackCalleeOperand = 2;
1288 
1289  OMPInformationCache::RuntimeFunctionInfo &RFI =
1290  OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1291 
1292  if (!RFI.Declaration)
1293  return false;
1294 
1295  bool Changed = false;
1296  auto DeleteCallCB = [&](Use &U, Function &) {
1297  CallInst *CI = getCallIfRegularCall(U);
1298  if (!CI)
1299  return false;
1300  auto *Fn = dyn_cast<Function>(
1301  CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1302  if (!Fn)
1303  return false;
1304  if (!Fn->onlyReadsMemory())
1305  return false;
1306  if (!Fn->hasFnAttribute(Attribute::WillReturn))
1307  return false;
1308 
1309  LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1310  << CI->getCaller()->getName() << "\n");
1311 
1312  auto Remark = [&](OptimizationRemark OR) {
1313  return OR << "Removing parallel region with no side-effects.";
1314  };
1315  emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1316 
1317  CGUpdater.removeCallSite(*CI);
1318  CI->eraseFromParent();
1319  Changed = true;
1320  ++NumOpenMPParallelRegionsDeleted;
1321  return true;
1322  };
1323 
1324  RFI.foreachUse(SCC, DeleteCallCB);
1325 
1326  return Changed;
1327  }
1328 
1329  /// Try to eliminate runtime calls by reusing existing ones.
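  /// For illustration (example not from this file):
  ///   int A = omp_get_num_threads();
  ///   ...
  ///   int B = omp_get_num_threads(); // second call replaced by reusing A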
1330  bool deduplicateRuntimeCalls() {
1331  bool Changed = false;
1332 
1333  RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1334  OMPRTL_omp_get_num_threads,
1335  OMPRTL_omp_in_parallel,
1336  OMPRTL_omp_get_cancellation,
1337  OMPRTL_omp_get_thread_limit,
1338  OMPRTL_omp_get_supported_active_levels,
1339  OMPRTL_omp_get_level,
1340  OMPRTL_omp_get_ancestor_thread_num,
1341  OMPRTL_omp_get_team_size,
1342  OMPRTL_omp_get_active_level,
1343  OMPRTL_omp_in_final,
1344  OMPRTL_omp_get_proc_bind,
1345  OMPRTL_omp_get_num_places,
1346  OMPRTL_omp_get_num_procs,
1347  OMPRTL_omp_get_place_num,
1348  OMPRTL_omp_get_partition_num_places,
1349  OMPRTL_omp_get_partition_place_nums};
1350 
1351  // Global-tid is handled separately.
1352  SmallSetVector<Value *, 16> GTIdArgs;
1353  collectGlobalThreadIdArguments(GTIdArgs);
1354  LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1355  << " global thread ID arguments\n");
1356 
1357  for (Function *F : SCC) {
1358  for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1359  Changed |= deduplicateRuntimeCalls(
1360  *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1361 
1362  // __kmpc_global_thread_num is special as we can replace it with an
1363  // argument in enough cases to make it worth trying.
1364  Value *GTIdArg = nullptr;
1365  for (Argument &Arg : F->args())
1366  if (GTIdArgs.count(&Arg)) {
1367  GTIdArg = &Arg;
1368  break;
1369  }
1370  Changed |= deduplicateRuntimeCalls(
1371  *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1372  }
1373 
1374  return Changed;
1375  }
1376 
1377  /// Tries to hide the latency of runtime calls that involve host to
1378  /// device memory transfers by splitting them into their "issue" and "wait"
1379  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1380  /// moved downwards as much as possible. The "issue" starts the memory transfer
1381  /// asynchronously, returning a handle. The "wait" waits on the returned
1382  /// handle for the memory transfer to finish.
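  /// Conceptually (simplified sketch of what splitTargetDataBeginRTC emits):
  ///   %handle = call @__tgt_target_data_begin_mapper_issue(...)  ; issued early
  ///   ... unrelated computation ...
  ///   call @__tgt_target_data_begin_mapper_wait(..., %handle)    ; waited late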
1383  bool hideMemTransfersLatency() {
1384  auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1385  bool Changed = false;
1386  auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1387  auto *RTCall = getCallIfRegularCall(U, &RFI);
1388  if (!RTCall)
1389  return false;
1390 
1391  OffloadArray OffloadArrays[3];
1392  if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1393  return false;
1394 
1395  LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1396 
1397  // TODO: Check if can be moved upwards.
1398  bool WasSplit = false;
1399  Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1400  if (WaitMovementPoint)
1401  WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1402 
1403  Changed |= WasSplit;
1404  return WasSplit;
1405  };
1406  RFI.foreachUse(SCC, SplitMemTransfers);
1407 
1408  return Changed;
1409  }
1410 
1411  /// Eliminates redundant, aligned barriers in OpenMP offloaded kernels.
1412  /// TODO: Make this an AA and expand it to work across blocks and functions.
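  /// For illustration (simplified, not from this file): within one kernel block
  ///   call void @llvm.nvvm.barrier0()
  ///   %v = load i32, i32* %stack.alloca   ; cannot be affected by a barrier
  ///   call void @llvm.nvvm.barrier0()
  /// one of the two explicit aligned barriers in the pair is removable.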
1413  bool eliminateBarriers() {
1414  bool Changed = false;
1415 
1416  if (DisableOpenMPOptBarrierElimination)
1417  return /*Changed=*/false;
1418 
1419  if (OMPInfoCache.Kernels.empty())
1420  return /*Changed=*/false;
1421 
1422  enum ImplicitBarrierType { IBT_ENTRY, IBT_EXIT };
1423 
1424  class BarrierInfo {
1425  Instruction *I;
1426  enum ImplicitBarrierType Type;
1427 
1428  public:
1429  BarrierInfo(enum ImplicitBarrierType Type) : I(nullptr), Type(Type) {}
1430  BarrierInfo(Instruction &I) : I(&I) {}
1431 
1432  bool isImplicit() { return !I; }
1433 
1434  bool isImplicitEntry() { return isImplicit() && Type == IBT_ENTRY; }
1435 
1436  bool isImplicitExit() { return isImplicit() && Type == IBT_EXIT; }
1437 
1438  Instruction *getInstruction() { return I; }
1439  };
1440 
1441  for (Function *Kernel : OMPInfoCache.Kernels) {
1442  for (BasicBlock &BB : *Kernel) {
1443  SmallVector<BarrierInfo, 8> BarriersInBlock;
1444  SmallPtrSet<Instruction *, 8> BarriersToBeDeleted;
1445 
1446  // Add the kernel entry implicit barrier.
1447  if (&Kernel->getEntryBlock() == &BB)
1448  BarriersInBlock.push_back(IBT_ENTRY);
1449 
1450  // Find implicit and explicit aligned barriers in the same basic block.
1451  for (Instruction &I : BB) {
1452  if (isa<ReturnInst>(I)) {
1453  // Add the implicit barrier when exiting the kernel.
1454  BarriersInBlock.push_back(IBT_EXIT);
1455  continue;
1456  }
1457  CallBase *CB = dyn_cast<CallBase>(&I);
1458  if (!CB)
1459  continue;
1460 
1461  auto IsAlignBarrierCB = [&](CallBase &CB) {
1462  switch (CB.getIntrinsicID()) {
1463  case Intrinsic::nvvm_barrier0:
1464  case Intrinsic::nvvm_barrier0_and:
1465  case Intrinsic::nvvm_barrier0_or:
1466  case Intrinsic::nvvm_barrier0_popc:
1467  return true;
1468  default:
1469  break;
1470  }
1471  return hasAssumption(CB,
1472  KnownAssumptionString("ompx_aligned_barrier"));
1473  };
1474 
1475  if (IsAlignBarrierCB(*CB)) {
1476  // Add an explicit aligned barrier.
1477  BarriersInBlock.push_back(I);
1478  }
1479  }
1480 
1481  if (BarriersInBlock.size() <= 1)
1482  continue;
1483 
1484  // A barrier in a barrier pair is removeable if all instructions
1485  // between the barriers in the pair are side-effect free modulo the
1486  // barrier operation.
1487  auto IsBarrierRemoveable = [&Kernel](BarrierInfo *StartBI,
1488  BarrierInfo *EndBI) {
1489  assert(
1490  !StartBI->isImplicitExit() &&
1491  "Expected start barrier to be other than a kernel exit barrier");
1492  assert(
1493  !EndBI->isImplicitEntry() &&
1494  "Expected end barrier to be other than a kernel entry barrier");
1495  // If the StartBI instruction is null then this is the implicit
1496  // kernel entry barrier, so iterate from the first instruction in the
1497  // entry block.
1498  Instruction *I = (StartBI->isImplicitEntry())
1499  ? &Kernel->getEntryBlock().front()
1500  : StartBI->getInstruction()->getNextNode();
1501  assert(I && "Expected non-null start instruction");
1502  Instruction *E = (EndBI->isImplicitExit())
1503  ? I->getParent()->getTerminator()
1504  : EndBI->getInstruction();
1505  assert(E && "Expected non-null end instruction");
1506 
1507  for (; I != E; I = I->getNextNode()) {
1508  if (!I->mayHaveSideEffects() && !I->mayReadFromMemory())
1509  continue;
1510 
1511  auto IsPotentiallyAffectedByBarrier =
1512  [](Optional<MemoryLocation> Loc) {
1513  const Value *Obj = (Loc && Loc->Ptr)
1514  ? getUnderlyingObject(Loc->Ptr)
1515  : nullptr;
1516  if (!Obj) {
1517  LLVM_DEBUG(
1518  dbgs()
1519  << "Access to unknown location requires barriers\n");
1520  return true;
1521  }
1522  if (isa<UndefValue>(Obj))
1523  return false;
1524  if (isa<AllocaInst>(Obj))
1525  return false;
1526  if (auto *GV = dyn_cast<GlobalVariable>(Obj)) {
1527  if (GV->isConstant())
1528  return false;
1529  if (GV->isThreadLocal())
1530  return false;
1531  if (GV->getAddressSpace() == (int)AddressSpace::Local)
1532  return false;
1533  if (GV->getAddressSpace() == (int)AddressSpace::Constant)
1534  return false;
1535  }
1536  LLVM_DEBUG(dbgs() << "Access to '" << *Obj
1537  << "' requires barriers\n");
1538  return true;
1539  };
1540 
1541  if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
1542  Optional<MemoryLocation> Loc = MemoryLocation::getForDest(MI);
1543  if (IsPotentiallyAffectedByBarrier(Loc))
1544  return false;
1545  if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(I)) {
1546  Optional<MemoryLocation> Loc =
1547  MemoryLocation::getForSource(MTI);
1548  if (IsPotentiallyAffectedByBarrier(Loc))
1549  return false;
1550  }
1551  continue;
1552  }
1553 
1554  if (auto *LI = dyn_cast<LoadInst>(I))
1555  if (LI->hasMetadata(LLVMContext::MD_invariant_load))
1556  continue;
1557 
1558  Optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I);
1559  if (IsPotentiallyAffectedByBarrier(Loc))
1560  return false;
1561  }
1562 
1563  return true;
1564  };
1565 
1566  // Iterate barrier pairs and remove an explicit barrier if analysis
1567  // deems it removeable.
1568  for (auto *It = BarriersInBlock.begin(),
1569  *End = BarriersInBlock.end() - 1;
1570  It != End; ++It) {
1571 
1572  BarrierInfo *StartBI = It;
1573  BarrierInfo *EndBI = (It + 1);
1574 
1575  // Cannot remove when both are implicit barriers, continue.
1576  if (StartBI->isImplicit() && EndBI->isImplicit())
1577  continue;
1578 
1579  if (!IsBarrierRemoveable(StartBI, EndBI))
1580  continue;
1581 
1582  assert(!(StartBI->isImplicit() && EndBI->isImplicit()) &&
1583  "Expected at least one explicit barrier to remove.");
1584 
1585  // Remove an explicit barrier, check first, then second.
1586  if (!StartBI->isImplicit()) {
1587  LLVM_DEBUG(dbgs() << "Remove start barrier "
1588  << *StartBI->getInstruction() << "\n");
1589  BarriersToBeDeleted.insert(StartBI->getInstruction());
1590  } else {
1591  LLVM_DEBUG(dbgs() << "Remove end barrier "
1592  << *EndBI->getInstruction() << "\n");
1593  BarriersToBeDeleted.insert(EndBI->getInstruction());
1594  }
1595  }
1596 
1597  if (BarriersToBeDeleted.empty())
1598  continue;
1599 
1600  Changed = true;
1601  for (Instruction *I : BarriersToBeDeleted) {
1602  ++NumBarriersEliminated;
1603  auto Remark = [&](OptimizationRemark OR) {
1604  return OR << "Redundant barrier eliminated.";
1605  };
1606 
1607  if (EnableVerboseRemarks)
1608  emitRemark<OptimizationRemark>(I, "OMP190", Remark);
1609  I->eraseFromParent();
1610  }
1611  }
1612  }
1613 
1614  return Changed;
1615  }
1616 
1617  void analysisGlobalization() {
1618  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1619 
1620  auto CheckGlobalization = [&](Use &U, Function &Decl) {
1621  if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1622  auto Remark = [&](OptimizationRemarkMissed ORM) {
1623  return ORM
1624  << "Found thread data sharing on the GPU. "
1625  << "Expect degraded performance due to data globalization.";
1626  };
1627  emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1628  }
1629 
1630  return false;
1631  };
1632 
1633  RFI.foreachUse(SCC, CheckGlobalization);
1634  }
1635 
1636  /// Maps the values stored in the offload arrays passed as arguments to
1637  /// \p RuntimeCall into the offload arrays in \p OAs.
1638  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1639  MutableArrayRef<OffloadArray> OAs) {
1640  assert(OAs.size() == 3 && "Need space for three offload arrays!");
1641 
1642  // A runtime call that involves memory offloading looks something like:
1643  // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1644  // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1645  // ...)
1646  // So, the idea is to access the allocas that allocate space for these
1647  // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1648  // Therefore:
1649  // i8** %offload_baseptrs.
1650  Value *BasePtrsArg =
1651  RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1652  // i8** %offload_ptrs.
1653  Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1654  // i8** %offload_sizes.
1655  Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1656 
1657  // Get values stored in **offload_baseptrs.
1658  auto *V = getUnderlyingObject(BasePtrsArg);
1659  if (!isa<AllocaInst>(V))
1660  return false;
1661  auto *BasePtrsArray = cast<AllocaInst>(V);
1662  if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1663  return false;
1664 
1665  // Get values stored in **offload_ptrs.
1666  V = getUnderlyingObject(PtrsArg);
1667  if (!isa<AllocaInst>(V))
1668  return false;
1669  auto *PtrsArray = cast<AllocaInst>(V);
1670  if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1671  return false;
1672 
1673  // Get values stored in **offload_sizes.
1674  V = getUnderlyingObject(SizesArg);
1675  // If it's a [constant] global array don't analyze it.
1676  if (isa<GlobalValue>(V))
1677  return isa<Constant>(V);
1678  if (!isa<AllocaInst>(V))
1679  return false;
1680 
1681  auto *SizesArray = cast<AllocaInst>(V);
1682  if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1683  return false;
1684 
1685  return true;
1686  }
1687 
1688  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1689  /// For now this is a way to test that the function getValuesInOffloadArrays
1690  /// is working properly.
1691  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1692  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1693  assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1694 
1695  LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1696  std::string ValuesStr;
1697  raw_string_ostream Printer(ValuesStr);
1698  std::string Separator = " --- ";
1699 
1700  for (auto *BP : OAs[0].StoredValues) {
1701  BP->print(Printer);
1702  Printer << Separator;
1703  }
1704  LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1705  ValuesStr.clear();
1706 
1707  for (auto *P : OAs[1].StoredValues) {
1708  P->print(Printer);
1709  Printer << Separator;
1710  }
1711  LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1712  ValuesStr.clear();
1713 
1714  for (auto *S : OAs[2].StoredValues) {
1715  S->print(Printer);
1716  Printer << Separator;
1717  }
1718  LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1719  }
1720 
1721  /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1722  /// moved. Returns nullptr if the movement is not possible, or not worth it.
1723  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1724  // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1725  // Make it traverse the CFG.
1726 
1727  Instruction *CurrentI = &RuntimeCall;
1728  bool IsWorthIt = false;
1729  while ((CurrentI = CurrentI->getNextNode())) {
1730 
1731  // TODO: Once we detect the regions to be offloaded we should use the
1732  // alias analysis manager to check if CurrentI may modify one of
1733  // the offloaded regions.
1734  if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1735  if (IsWorthIt)
1736  return CurrentI;
1737 
1738  return nullptr;
1739  }
1740 
1741  // FIXME: For now, moving it over anything without side effects
1742  // is considered worth it.
1743  IsWorthIt = true;
1744  }
1745 
1746  // Return end of BasicBlock.
1747  return RuntimeCall.getParent()->getTerminator();
1748  }
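// Illustrative sketch (assumed IR, not from the source): the "wait" point can
// be pushed past instructions that neither read memory nor have side effects,
// e.g. past
//
//   %a = add i32 %x, %y
//   %b = shl i32 %a, 1
//
// but movement stops at the first load, store, or call encountered, and is
// only considered worth it if at least one instruction was skipped.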
1749 
1750  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1751  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1752  Instruction &WaitMovementPoint) {
1753  // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1754  // the function. It stores information about the async transfer so we can
1755  // wait on it later.
1756  auto &IRBuilder = OMPInfoCache.OMPBuilder;
1757  auto *F = RuntimeCall.getCaller();
1758  Instruction *FirstInst = &(F->getEntryBlock().front());
1759  AllocaInst *Handle = new AllocaInst(
1760  IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1761 
1762  // Add "issue" runtime call declaration:
1763  // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1764  // i8**, i8**, i64*, i64*)
1765  FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1766  M, OMPRTL___tgt_target_data_begin_mapper_issue);
1767 
1768  // Change the RuntimeCall call site to its asynchronous version.
1769  SmallVector<Value *, 16> Args;
1770  for (auto &Arg : RuntimeCall.args())
1771  Args.push_back(Arg.get());
1772  Args.push_back(Handle);
1773 
1774  CallInst *IssueCallsite =
1775  CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1776  OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1777  RuntimeCall.eraseFromParent();
1778 
1779  // Add "wait" runtime call declaration:
1780  // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1781  FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1782  M, OMPRTL___tgt_target_data_begin_mapper_wait);
1783 
1784  Value *WaitParams[2] = {
1785  IssueCallsite->getArgOperand(
1786  OffloadArray::DeviceIDArgNum), // device_id.
1787  Handle // handle to wait on.
1788  };
1789  CallInst *WaitCallsite = CallInst::Create(
1790  WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1791  OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1792 
1793  return true;
1794  }
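// Illustrative sketch of the split (simplified IR, argument lists abbreviated;
// not from the source):
//
//   call void @__tgt_target_data_begin_mapper(%ident, i64 %device_id, ...)
//
// becomes
//
//   %handle = alloca %struct.__tgt_async_info
//   call void @__tgt_target_data_begin_mapper_issue(%ident, i64 %device_id,
//                                  ..., %struct.__tgt_async_info* %handle)
//   ; ... instructions the wait was moved over ...
//   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id,
//                                  %struct.__tgt_async_info* %handle)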
1795 
1796  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1797  bool GlobalOnly, bool &SingleChoice) {
1798  if (CurrentIdent == NextIdent)
1799  return CurrentIdent;
1800 
1801  // TODO: Figure out how to actually combine multiple debug locations. For
1802  // now we just keep an existing one if there is a single choice.
1803  if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1804  SingleChoice = !CurrentIdent;
1805  return NextIdent;
1806  }
1807  return nullptr;
1808  }
1809 
1810  /// Return a `struct ident_t*` value that represents the ones used in the
1811  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1812  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1813  /// return value we create one from scratch. We also do not yet combine
1814  /// information, e.g., the source locations, see combinedIdentStruct.
1815  Value *
1816  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1817  Function &F, bool GlobalOnly) {
1818  bool SingleChoice = true;
1819  Value *Ident = nullptr;
1820  auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1821  CallInst *CI = getCallIfRegularCall(U, &RFI);
1822  if (!CI || &F != &Caller)
1823  return false;
1824  Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1825  /* GlobalOnly */ true, SingleChoice);
1826  return false;
1827  };
1828  RFI.foreachUse(SCC, CombineIdentStruct);
1829 
1830  if (!Ident || !SingleChoice) {
1831  // The IRBuilder uses the insertion block to get to the module, this is
1832  // unfortunate but we work around it for now.
1833  if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1834  OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1835  &F.getEntryBlock(), F.getEntryBlock().begin()));
1836  // Create a fallback location if none was found.
1837  // TODO: Use the debug locations of the calls instead.
1838  uint32_t SrcLocStrSize;
1839  Constant *Loc =
1840  OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1841  Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1842  }
1843  return Ident;
1844  }
1845 
1846  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1847  /// \p ReplVal if given.
1848  bool deduplicateRuntimeCalls(Function &F,
1849  OMPInformationCache::RuntimeFunctionInfo &RFI,
1850  Value *ReplVal = nullptr) {
1851  auto *UV = RFI.getUseVector(F);
1852  if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1853  return false;
1854 
1855  LLVM_DEBUG(
1856  dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1857  << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1858 
1859  assert((!ReplVal || (isa<Argument>(ReplVal) &&
1860  cast<Argument>(ReplVal)->getParent() == &F)) &&
1861  "Unexpected replacement value!");
1862 
1863  // TODO: Use dominance to find a good position instead.
1864  auto CanBeMoved = [this](CallBase &CB) {
1865  unsigned NumArgs = CB.arg_size();
1866  if (NumArgs == 0)
1867  return true;
1868  if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1869  return false;
1870  for (unsigned U = 1; U < NumArgs; ++U)
1871  if (isa<Instruction>(CB.getArgOperand(U)))
1872  return false;
1873  return true;
1874  };
1875 
1876  if (!ReplVal) {
1877  for (Use *U : *UV)
1878  if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1879  if (!CanBeMoved(*CI))
1880  continue;
1881 
1882  // If the function is a kernel, dedup will move
1883  // the runtime call right after the kernel init callsite. Otherwise,
1884  // it will move it to the beginning of the caller function.
1885  if (isKernel(F)) {
1886  auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1887  auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1888 
1889  if (KernelInitUV->empty())
1890  continue;
1891 
1892  assert(KernelInitUV->size() == 1 &&
1893  "Expected a single __kmpc_target_init in kernel\n");
1894 
1895  CallInst *KernelInitCI =
1896  getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1897  assert(KernelInitCI &&
1898  "Expected a call to __kmpc_target_init in kernel\n");
1899 
1900  CI->moveAfter(KernelInitCI);
1901  } else
1902  CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1903  ReplVal = CI;
1904  break;
1905  }
1906  if (!ReplVal)
1907  return false;
1908  }
1909 
1910  // If we use a call as a replacement value we need to make sure the ident is
1911  // valid at the new location. For now we just pick a global one, either
1912  // existing and used by one of the calls, or created from scratch.
1913  if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1914  if (!CI->arg_empty() &&
1915  CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1916  Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1917  /* GlobalOnly */ true);
1918  CI->setArgOperand(0, Ident);
1919  }
1920  }
1921 
1922  bool Changed = false;
1923  auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1924  CallInst *CI = getCallIfRegularCall(U, &RFI);
1925  if (!CI || CI == ReplVal || &F != &Caller)
1926  return false;
1927  assert(CI->getCaller() == &F && "Unexpected call!");
1928 
1929  auto Remark = [&](OptimizationRemark OR) {
1930  return OR << "OpenMP runtime call "
1931  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1932  };
1933  if (CI->getDebugLoc())
1934  emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1935  else
1936  emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1937 
1938  CGUpdater.removeCallSite(*CI);
1939  CI->replaceAllUsesWith(ReplVal);
1940  CI->eraseFromParent();
1941  ++NumOpenMPRuntimeCallsDeduplicated;
1942  Changed = true;
1943  return true;
1944  };
1945  RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1946 
1947  return Changed;
1948  }
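// Illustrative sketch (assumed user code, not from the source): two calls such
// as
//
//   int a = omp_get_thread_num();
//   ...
//   int b = omp_get_thread_num();
//
// are deduplicated by keeping one call as ReplVal, moving it to the entry
// block (or, in a kernel, right after the __kmpc_target_init call), and
// replacing all remaining calls with its result.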
1949 
1950  /// Collect arguments that represent the global thread id in \p GTIdArgs.
1951  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1952  // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1953  // initialization. We could define an AbstractAttribute instead and
1954  // run the Attributor here once it can be run as an SCC pass.
1955 
1956  // Helper to check the argument \p ArgNo at all call sites of \p F for
1957  // a GTId.
1958  auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1959  if (!F.hasLocalLinkage())
1960  return false;
1961  for (Use &U : F.uses()) {
1962  if (CallInst *CI = getCallIfRegularCall(U)) {
1963  Value *ArgOp = CI->getArgOperand(ArgNo);
1964  if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1965  getCallIfRegularCall(
1966  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1967  continue;
1968  }
1969  return false;
1970  }
1971  return true;
1972  };
1973 
1974  // Helper to identify uses of a GTId as GTId arguments.
1975  auto AddUserArgs = [&](Value &GTId) {
1976  for (Use &U : GTId.uses())
1977  if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1978  if (CI->isArgOperand(&U))
1979  if (Function *Callee = CI->getCalledFunction())
1980  if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1981  GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1982  };
1983 
1984  // The argument users of __kmpc_global_thread_num calls are GTIds.
1985  OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1986  OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1987 
1988  GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1989  if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1990  AddUserArgs(*CI);
1991  return false;
1992  });
1993 
1994  // Transitively search for more arguments by looking at the users of the
1995  // ones we know already. During the search the GTIdArgs vector is extended
1996  // so we cannot cache the size nor can we use a range based for.
1997  for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1998  AddUserArgs(*GTIdArgs[U]);
1999  }
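// Illustrative sketch (assumed IR, hypothetical names): given
//
//   %gtid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
//   call void @outlined(i32 %gtid, ...)
//
// the first parameter of @outlined is added to GTIdArgs if every call site of
// @outlined passes a known GTId (or another GTId argument) in that position;
// the search then continues transitively through the users of that parameter.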
2000 
2001  /// Kernel (=GPU) optimizations and utility functions
2002  ///
2003  ///{{
2004 
2005  /// Check if \p F is a kernel, hence entry point for target offloading.
2006  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
2007 
2008  /// Cache to remember the unique kernel for a function.
2009  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
2010 
2011  /// Find the unique kernel that will execute \p F, if any.
2012  Kernel getUniqueKernelFor(Function &F);
2013 
2014  /// Find the unique kernel that will execute \p I, if any.
2015  Kernel getUniqueKernelFor(Instruction &I) {
2016  return getUniqueKernelFor(*I.getFunction());
2017  }
2018 
2019  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
2020  /// the cases where we can avoid taking the address of a function.
2021  bool rewriteDeviceCodeStateMachine();
2022 
2023  ///
2024  ///}}
2025 
2026  /// Emit a remark generically
2027  ///
2028  /// This template function can be used to generically emit a remark. The
2029  /// RemarkKind should be one of the following:
2030  /// - OptimizationRemark to indicate a successful optimization attempt
2031  /// - OptimizationRemarkMissed to report a failed optimization attempt
2032  /// - OptimizationRemarkAnalysis to provide additional information about an
2033  /// optimization attempt
2034  ///
2035  /// The remark is built using a callback function provided by the caller that
2036  /// takes a RemarkKind as input and returns a RemarkKind.
2037  template <typename RemarkKind, typename RemarkCallBack>
2038  void emitRemark(Instruction *I, StringRef RemarkName,
2039  RemarkCallBack &&RemarkCB) const {
2040  Function *F = I->getParent()->getParent();
2041  auto &ORE = OREGetter(F);
2042 
2043  if (RemarkName.startswith("OMP"))
2044  ORE.emit([&]() {
2045  return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2046  << " [" << RemarkName << "]";
2047  });
2048  else
2049  ORE.emit(
2050  [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2051  }
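// Illustrative usage, mirroring the call sites elsewhere in this file:
//
//   auto Remark = [&](OptimizationRemark OR) {
//     return OR << "Redundant barrier eliminated.";
//   };
//   emitRemark<OptimizationRemark>(I, "OMP190", Remark);
//
// Remark names starting with "OMP" are appended to the message as a tag, which
// matches the identifiers documented at
// https://openmp.llvm.org/remarks/OptimizationRemarks.html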
2052 
2053  /// Emit a remark on a function.
2054  template <typename RemarkKind, typename RemarkCallBack>
2055  void emitRemark(Function *F, StringRef RemarkName,
2056  RemarkCallBack &&RemarkCB) const {
2057  auto &ORE = OREGetter(F);
2058 
2059  if (RemarkName.startswith("OMP"))
2060  ORE.emit([&]() {
2061  return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2062  << " [" << RemarkName << "]";
2063  });
2064  else
2065  ORE.emit(
2066  [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2067  }
2068 
2069  /// RAII struct to temporarily change an RTL function's linkage to external.
2070  /// This prevents it from being mistakenly removed by other optimizations.
2071  struct ExternalizationRAII {
2072  ExternalizationRAII(OMPInformationCache &OMPInfoCache,
2073  RuntimeFunction RFKind)
2074  : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
2075  if (!Declaration)
2076  return;
2077 
2078  LinkageType = Declaration->getLinkage();
2079  Declaration->setLinkage(GlobalValue::ExternalLinkage);
2080  }
2081 
2082  ~ExternalizationRAII() {
2083  if (!Declaration)
2084  return;
2085 
2086  Declaration->setLinkage(LinkageType);
2087  }
2088 
2089  Function *Declaration;
2090  GlobalValue::LinkageTypes LinkageType;
2091  };
2092 
2093  /// The underlying module.
2094  Module &M;
2095 
2096  /// The SCC we are operating on.
2097  SmallVectorImpl<Function *> &SCC;
2098 
2099  /// Callback to update the call graph; the first argument is a removed call,
2100  /// the second an optional replacement call.
2101  CallGraphUpdater &CGUpdater;
2102 
2103  /// Callback to get an OptimizationRemarkEmitter from a Function *
2104  OptimizationRemarkGetter OREGetter;
2105 
2106  /// OpenMP-specific information cache. Also used for Attributor runs.
2107  OMPInformationCache &OMPInfoCache;
2108 
2109  /// Attributor instance.
2110  Attributor &A;
2111 
2112  /// Helper function to run Attributor on SCC.
2113  bool runAttributor(bool IsModulePass) {
2114  if (SCC.empty())
2115  return false;
2116 
2117  // Temporarily make these functions have external linkage so the Attributor
2118  // doesn't remove them when we try to look them up later.
2119  ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
2120  ExternalizationRAII EndParallel(OMPInfoCache,
2121  OMPRTL___kmpc_kernel_end_parallel);
2122  ExternalizationRAII BarrierSPMD(OMPInfoCache,
2123  OMPRTL___kmpc_barrier_simple_spmd);
2124  ExternalizationRAII BarrierGeneric(OMPInfoCache,
2125  OMPRTL___kmpc_barrier_simple_generic);
2126  ExternalizationRAII ThreadId(OMPInfoCache,
2127  OMPRTL___kmpc_get_hardware_thread_id_in_block);
2128  ExternalizationRAII NumThreads(
2129  OMPInfoCache, OMPRTL___kmpc_get_hardware_num_threads_in_block);
2130  ExternalizationRAII WarpSize(OMPInfoCache, OMPRTL___kmpc_get_warp_size);
2131 
2132  registerAAs(IsModulePass);
2133 
2134  ChangeStatus Changed = A.run();
2135 
2136  LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2137  << " functions, result: " << Changed << ".\n");
2138 
2139  return Changed == ChangeStatus::CHANGED;
2140  }
2141 
2142  void registerFoldRuntimeCall(RuntimeFunction RF);
2143 
2144  /// Populate the Attributor with abstract attribute opportunities in the
2145  /// function.
2146  void registerAAs(bool IsModulePass);
2147 };
2148 
2149 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2150  if (!OMPInfoCache.ModuleSlice.empty() && !OMPInfoCache.ModuleSlice.count(&F))
2151  return nullptr;
2152 
2153  // Use a scope to keep the lifetime of the CachedKernel short.
2154  {
2155  Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2156  if (CachedKernel)
2157  return *CachedKernel;
2158 
2159  // TODO: We should use an AA to create an (optimistic and callback
2160  // call-aware) call graph. For now we stick to simple patterns that
2161  // are less powerful, basically the worst fixpoint.
2162  if (isKernel(F)) {
2163  CachedKernel = Kernel(&F);
2164  return *CachedKernel;
2165  }
2166 
2167  CachedKernel = nullptr;
2168  if (!F.hasLocalLinkage()) {
2169 
2170  // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2171  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2172  return ORA << "Potentially unknown OpenMP target region caller.";
2173  };
2174  emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2175 
2176  return nullptr;
2177  }
2178  }
2179 
2180  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2181  if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2182  // Allow use in equality comparisons.
2183  if (Cmp->isEquality())
2184  return getUniqueKernelFor(*Cmp);
2185  return nullptr;
2186  }
2187  if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2188  // Allow direct calls.
2189  if (CB->isCallee(&U))
2190  return getUniqueKernelFor(*CB);
2191 
2192  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2193  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2194  // Allow the use in __kmpc_parallel_51 calls.
2195  if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2196  return getUniqueKernelFor(*CB);
2197  return nullptr;
2198  }
2199  // Disallow every other use.
2200  return nullptr;
2201  };
2202 
2203  // TODO: In the future we want to track more than just a unique kernel.
2204  SmallPtrSet<Kernel, 2> PotentialKernels;
2205  OMPInformationCache::foreachUse(F, [&](const Use &U) {
2206  PotentialKernels.insert(GetUniqueKernelForUse(U));
2207  });
2208 
2209  Kernel K = nullptr;
2210  if (PotentialKernels.size() == 1)
2211  K = *PotentialKernels.begin();
2212 
2213  // Cache the result.
2214  UniqueKernelMap[&F] = K;
2215 
2216  return K;
2217 }
2218 
2219 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2220  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2221  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2222 
2223  bool Changed = false;
2224  if (!KernelParallelRFI)
2225  return Changed;
2226 
2227  // If we have disabled state machine changes, exit.
2228  if (DisableOpenMPOptStateMachineRewrite)
2229  return Changed;
2230 
2231  for (Function *F : SCC) {
2232 
2233  // Check if the function is a use in a __kmpc_parallel_51 call at
2234  // all.
2235  bool UnknownUse = false;
2236  bool KernelParallelUse = false;
2237  unsigned NumDirectCalls = 0;
2238 
2239  SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2240  OMPInformationCache::foreachUse(*F, [&](Use &U) {
2241  if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2242  if (CB->isCallee(&U)) {
2243  ++NumDirectCalls;
2244  return;
2245  }
2246 
2247  if (isa<ICmpInst>(U.getUser())) {
2248  ToBeReplacedStateMachineUses.push_back(&U);
2249  return;
2250  }
2251 
2252  // Find wrapper functions that represent parallel kernels.
2253  CallInst *CI =
2254  OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2255  const unsigned int WrapperFunctionArgNo = 6;
2256  if (!KernelParallelUse && CI &&
2257  CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2258  KernelParallelUse = true;
2259  ToBeReplacedStateMachineUses.push_back(&U);
2260  return;
2261  }
2262  UnknownUse = true;
2263  });
2264 
2265  // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2266  // use.
2267  if (!KernelParallelUse)
2268  continue;
2269 
2270  // If this ever hits, we should investigate.
2271  // TODO: Checking the number of uses is not a necessary restriction and
2272  // should be lifted.
2273  if (UnknownUse || NumDirectCalls != 1 ||
2274  ToBeReplacedStateMachineUses.size() > 2) {
2275  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2276  return ORA << "Parallel region is used in "
2277  << (UnknownUse ? "unknown" : "unexpected")
2278  << " ways. Will not attempt to rewrite the state machine.";
2279  };
2280  emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2281  continue;
2282  }
2283 
2284  // Even if we have __kmpc_parallel_51 calls, we (for now) give
2285  // up if the function is not called from a unique kernel.
2286  Kernel K = getUniqueKernelFor(*F);
2287  if (!K) {
2288  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2289  return ORA << "Parallel region is not called from a unique kernel. "
2290  "Will not attempt to rewrite the state machine.";
2291  };
2292  emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2293  continue;
2294  }
2295 
2296  // We now know F is a parallel body function called only from the kernel K.
2297  // We also identified the state machine uses in which we replace the
2298  // function pointer by a new global symbol for identification purposes. This
2299  // ensures only direct calls to the function are left.
2300 
2301  Module &M = *F->getParent();
2302  Type *Int8Ty = Type::getInt8Ty(M.getContext());
2303 
2304  auto *ID = new GlobalVariable(
2305  M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2306  UndefValue::get(Int8Ty), F->getName() + ".ID");
2307 
2308  for (Use *U : ToBeReplacedStateMachineUses)
2309  U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2310  ID, U->get()->getType()));
2311 
2312  ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2313 
2314  Changed = true;
2315  }
2316 
2317  return Changed;
2318 }
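// Illustrative sketch (simplified, hypothetical names): instead of passing the
// parallel region function itself to __kmpc_parallel_51 and comparing against
// its address in the generic-mode state machine, a private one-byte global,
// e.g.
//
//   @__omp_outlined__1.ID = private constant i8 undef
//
// is passed and compared. The parallel region is then only reached through
// direct calls, and its address is no longer taken.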
2319 
2320 /// Abstract Attribute for tracking ICV values.
2321 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2322  using Base = StateWrapper<BooleanState, AbstractAttribute>;
2323  AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2324 
2325  void initialize(Attributor &A) override {
2326  Function *F = getAnchorScope();
2327  if (!F || !A.isFunctionIPOAmendable(*F))
2328  indicatePessimisticFixpoint();
2329  }
2330 
2331  /// Returns true if value is assumed to be tracked.
2332  bool isAssumedTracked() const { return getAssumed(); }
2333 
2334  /// Returns true if value is known to be tracked.
2335  bool isKnownTracked() const { return getAssumed(); }
2336 
2337  /// Create an abstract attribute view for the position \p IRP.
2338  static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2339 
2340  /// Return the value with which \p I can be replaced for specific \p ICV.
2341  virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2342  const Instruction *I,
2343  Attributor &A) const {
2344  return None;
2345  }
2346 
2347  /// Return an assumed unique ICV value if a single candidate is found. If
2348  /// there cannot be one, return a nullptr. If it is not clear yet, return the
2349  /// Optional::NoneType.
2350  virtual Optional<Value *>
2351  getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2352 
2353  // Currently only nthreads is being tracked.
2354  // This array will only grow with time.
2355  InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2356 
2357  /// See AbstractAttribute::getName()
2358  const std::string getName() const override { return "AAICVTracker"; }
2359 
2360  /// See AbstractAttribute::getIdAddr()
2361  const char *getIdAddr() const override { return &ID; }
2362 
2363  /// This function should return true if the type of the \p AA is AAICVTracker
2364  static bool classof(const AbstractAttribute *AA) {
2365  return (AA->getIdAddr() == &ID);
2366  }
2367 
2368  static const char ID;
2369 };
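// Illustrative sketch (assumed user code, not from the source) of what ICV
// tracking enables:
//
//   omp_set_num_threads(4);        // setter of the nthreads ICV
//   int n = omp_get_max_threads(); // getter, foldable to 4 here
//
// The getter can be folded only if no call in between may change the ICV
// again (see getValueForCall below).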
2370 
2371 struct AAICVTrackerFunction : public AAICVTracker {
2372  AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2373  : AAICVTracker(IRP, A) {}
2374 
2375  // FIXME: come up with better string.
2376  const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2377 
2378  // FIXME: come up with some stats.
2379  void trackStatistics() const override {}
2380 
2381  /// We don't manifest anything for this AA.
2382  ChangeStatus manifest(Attributor &A) override {
2383  return ChangeStatus::UNCHANGED;
2384  }
2385 
2386  // Map of ICV to their values at specific program point.
2388  InternalControlVar::ICV___last>
2389  ICVReplacementValuesMap;
2390 
2391  ChangeStatus updateImpl(Attributor &A) override {
2392  ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2393 
2394  Function *F = getAnchorScope();
2395 
2396  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2397 
2398  for (InternalControlVar ICV : TrackableICVs) {
2399  auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2400 
2401  auto &ValuesMap = ICVReplacementValuesMap[ICV];
2402  auto TrackValues = [&](Use &U, Function &) {
2403  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2404  if (!CI)
2405  return false;
2406 
2407  // FIXME: handle setters with more than one argument.
2408  /// Track new value.
2409  if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2410  HasChanged = ChangeStatus::CHANGED;
2411 
2412  return false;
2413  };
2414 
2415  auto CallCheck = [&](Instruction &I) {
2416  Optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2417  if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2418  HasChanged = ChangeStatus::CHANGED;
2419 
2420  return true;
2421  };
2422 
2423  // Track all changes of an ICV.
2424  SetterRFI.foreachUse(TrackValues, F);
2425 
2426  bool UsedAssumedInformation = false;
2427  A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2428  UsedAssumedInformation,
2429  /* CheckBBLivenessOnly */ true);
2430 
2431  /// TODO: Figure out a way to avoid adding entry in
2432  /// ICVReplacementValuesMap
2433  Instruction *Entry = &F->getEntryBlock().front();
2434  if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2435  ValuesMap.insert(std::make_pair(Entry, nullptr));
2436  }
2437 
2438  return HasChanged;
2439  }
2440 
2441  /// Helper to check if \p I is a call and get the value for it if it is
2442  /// unique.
2443  Optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2444  InternalControlVar &ICV) const {
2445 
2446  const auto *CB = dyn_cast<CallBase>(&I);
2447  if (!CB || CB->hasFnAttr("no_openmp") ||
2448  CB->hasFnAttr("no_openmp_routines"))
2449  return None;
2450 
2451  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2452  auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2453  auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2454  Function *CalledFunction = CB->getCalledFunction();
2455 
2456  // Indirect call, assume ICV changes.
2457  if (CalledFunction == nullptr)
2458  return nullptr;
2459  if (CalledFunction == GetterRFI.Declaration)
2460  return None;
2461  if (CalledFunction == SetterRFI.Declaration) {
2462  if (ICVReplacementValuesMap[ICV].count(&I))
2463  return ICVReplacementValuesMap[ICV].lookup(&I);
2464 
2465  return nullptr;
2466  }
2467 
2468  // Since we don't know, assume it changes the ICV.
2469  if (CalledFunction->isDeclaration())
2470  return nullptr;
2471 
2472  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2473  *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2474 
2475  if (ICVTrackingAA.isAssumedTracked()) {
2476  Optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV);
2477  if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2478  OMPInfoCache)))
2479  return URV;
2480  }
2481 
2482  // If we don't know, assume it changes.
2483  return nullptr;
2484  }
2485 
2486  // We don't check unique value for a function, so return None.
2487  Optional<Value *>
2488  getUniqueReplacementValue(InternalControlVar ICV) const override {
2489  return None;
2490  }
2491 
2492  /// Return the value with which \p I can be replaced for specific \p ICV.
2493  Optional<Value *> getReplacementValue(InternalControlVar ICV,
2494  const Instruction *I,
2495  Attributor &A) const override {
2496  const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2497  if (ValuesMap.count(I))
2498  return ValuesMap.lookup(I);
2499 
2500  SmallVector<const Instruction *, 4> Worklist;
2501  SmallPtrSet<const Instruction *, 4> Visited;
2502  Worklist.push_back(I);
2503 
2504  Optional<Value *> ReplVal;
2505 
2506  while (!Worklist.empty()) {
2507  const Instruction *CurrInst = Worklist.pop_back_val();
2508  if (!Visited.insert(CurrInst).second)
2509  continue;
2510 
2511  const BasicBlock *CurrBB = CurrInst->getParent();
2512 
2513  // Go up and look for all potential setters/calls that might change the
2514  // ICV.
2515  while ((CurrInst = CurrInst->getPrevNode())) {
2516  if (ValuesMap.count(CurrInst)) {
2517  Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2518  // Unknown value, track new.
2519  if (!ReplVal) {
2520  ReplVal = NewReplVal;
2521  break;
2522  }
2523 
2524  // If we found a new value, we can't know the ICV value anymore.
2525  if (NewReplVal)
2526  if (ReplVal != NewReplVal)
2527  return nullptr;
2528 
2529  break;
2530  }
2531 
2532  Optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2533  if (!NewReplVal)
2534  continue;
2535 
2536  // Unknown value, track new.
2537  if (!ReplVal) {
2538  ReplVal = NewReplVal;
2539  break;
2540  }
2541 
2542  // NewReplVal is known to hold a value at this point (checked above).
2543  // If it differs from what we tracked, we can't know the ICV value anymore.
2544  if (ReplVal != NewReplVal)
2545  return nullptr;
2546  }
2547 
2548  // If we are in the same BB and we have a value, we are done.
2549  if (CurrBB == I->getParent() && ReplVal)
2550  return ReplVal;
2551 
2552  // Go through all predecessors and add terminators for analysis.
2553  for (const BasicBlock *Pred : predecessors(CurrBB))
2554  if (const Instruction *Terminator = Pred->getTerminator())
2555  Worklist.push_back(Terminator);
2556  }
2557 
2558  return ReplVal;
2559  }
2560 };
2561 
2562 struct AAICVTrackerFunctionReturned : AAICVTracker {
2563  AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2564  : AAICVTracker(IRP, A) {}
2565 
2566  // FIXME: come up with better string.
2567  const std::string getAsStr() const override {
2568  return "ICVTrackerFunctionReturned";
2569  }
2570 
2571  // FIXME: come up with some stats.
2572  void trackStatistics() const override {}
2573 
2574  /// We don't manifest anything for this AA.
2575  ChangeStatus manifest(Attributor &A) override {
2576  return ChangeStatus::UNCHANGED;
2577  }
2578 
2579  // Map of ICV to their values at specific program point.
2580  EnumeratedArray<Optional<Value *>, InternalControlVar,
2581  InternalControlVar::ICV___last>
2582  ICVReplacementValuesMap;
2583 
2584  /// Return the unique replacement value for \p ICV, if any.
2585  Optional<Value *>
2586  getUniqueReplacementValue(InternalControlVar ICV) const override {
2587  return ICVReplacementValuesMap[ICV];
2588  }
2589 
2590  ChangeStatus updateImpl(Attributor &A) override {
2591  ChangeStatus Changed = ChangeStatus::UNCHANGED;
2592  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2593  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2594 
2595  if (!ICVTrackingAA.isAssumedTracked())
2596  return indicatePessimisticFixpoint();
2597 
2598  for (InternalControlVar ICV : TrackableICVs) {
2599  Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2600  Optional<Value *> UniqueICVValue;
2601 
2602  auto CheckReturnInst = [&](Instruction &I) {
2603  Optional<Value *> NewReplVal =
2604  ICVTrackingAA.getReplacementValue(ICV, &I, A);
2605 
2606  // If we found a second ICV value there is no unique returned value.
2607  if (UniqueICVValue && UniqueICVValue != NewReplVal)
2608  return false;
2609 
2610  UniqueICVValue = NewReplVal;
2611 
2612  return true;
2613  };
2614 
2615  bool UsedAssumedInformation = false;
2616  if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2617  UsedAssumedInformation,
2618  /* CheckBBLivenessOnly */ true))
2619  UniqueICVValue = nullptr;
2620 
2621  if (UniqueICVValue == ReplVal)
2622  continue;
2623 
2624  ReplVal = UniqueICVValue;
2625  Changed = ChangeStatus::CHANGED;
2626  }
2627 
2628  return Changed;
2629  }
2630 };
2631 
2632 struct AAICVTrackerCallSite : AAICVTracker {
2633  AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2634  : AAICVTracker(IRP, A) {}
2635 
2636  void initialize(Attributor &A) override {
2637  Function *F = getAnchorScope();
2638  if (!F || !A.isFunctionIPOAmendable(*F))
2639  indicatePessimisticFixpoint();
2640 
2641  // We only initialize this AA for getters, so we need to know which ICV it
2642  // gets.
2643  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2644  for (InternalControlVar ICV : TrackableICVs) {
2645  auto ICVInfo = OMPInfoCache.ICVs[ICV];
2646  auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2647  if (Getter.Declaration == getAssociatedFunction()) {
2648  AssociatedICV = ICVInfo.Kind;
2649  return;
2650  }
2651  }
2652 
2653  /// Unknown ICV.
2654  indicatePessimisticFixpoint();
2655  }
2656 
2657  ChangeStatus manifest(Attributor &A) override {
2658  if (!ReplVal || !*ReplVal)
2659  return ChangeStatus::UNCHANGED;
2660 
2661  A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2662  A.deleteAfterManifest(*getCtxI());
2663 
2664  return ChangeStatus::CHANGED;
2665  }
2666 
2667  // FIXME: come up with better string.
2668  const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2669 
2670  // FIXME: come up with some stats.
2671  void trackStatistics() const override {}
2672 
2673  InternalControlVar AssociatedICV;
2674  Optional<Value *> ReplVal;
2675 
2676  ChangeStatus updateImpl(Attributor &A) override {
2677  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2678  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2679 
2680  // We don't have any information, so we assume it changes the ICV.
2681  if (!ICVTrackingAA.isAssumedTracked())
2682  return indicatePessimisticFixpoint();
2683 
2684  Optional<Value *> NewReplVal =
2685  ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2686 
2687  if (ReplVal == NewReplVal)
2688  return ChangeStatus::UNCHANGED;
2689 
2690  ReplVal = NewReplVal;
2691  return ChangeStatus::CHANGED;
2692  }
2693 
2694  // Return the value with which associated value can be replaced for specific
2695  // \p ICV.
2696  Optional<Value *>
2697  getUniqueReplacementValue(InternalControlVar ICV) const override {
2698  return ReplVal;
2699  }
2700 };
2701 
2702 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2703  AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2704  : AAICVTracker(IRP, A) {}
2705 
2706  // FIXME: come up with better string.
2707  const std::string getAsStr() const override {
2708  return "ICVTrackerCallSiteReturned";
2709  }
2710 
2711  // FIXME: come up with some stats.
2712  void trackStatistics() const override {}
2713 
2714  /// We don't manifest anything for this AA.
2715  ChangeStatus manifest(Attributor &A) override {
2716  return ChangeStatus::UNCHANGED;
2717  }
2718 
2719  // Map of ICV to their values at specific program point.
2720  EnumeratedArray<Optional<Value *>, InternalControlVar,
2721  InternalControlVar::ICV___last>
2722  ICVReplacementValuesMap;
2723 
2724  /// Return the value with which associated value can be replaced for specific
2725  /// \p ICV.
2726  Optional<Value *>
2727  getUniqueReplacementValue(InternalControlVar ICV) const override {
2728  return ICVReplacementValuesMap[ICV];
2729  }
2730 
2731  ChangeStatus updateImpl(Attributor &A) override {
2732  ChangeStatus Changed = ChangeStatus::UNCHANGED;
2733  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2734  *this, IRPosition::returned(*getAssociatedFunction()),
2735  DepClassTy::REQUIRED);
2736 
2737  // We don't have any information, so we assume it changes the ICV.
2738  if (!ICVTrackingAA.isAssumedTracked())
2739  return indicatePessimisticFixpoint();
2740 
2741  for (InternalControlVar ICV : TrackableICVs) {
2742  Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2743  Optional<Value *> NewReplVal =
2744  ICVTrackingAA.getUniqueReplacementValue(ICV);
2745 
2746  if (ReplVal == NewReplVal)
2747  continue;
2748 
2749  ReplVal = NewReplVal;
2750  Changed = ChangeStatus::CHANGED;
2751  }
2752  return Changed;
2753  }
2754 };
2755 
2756 struct AAExecutionDomainFunction : public AAExecutionDomain {
2757  AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2758  : AAExecutionDomain(IRP, A) {}
2759 
2760  const std::string getAsStr() const override {
2761  return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2762  "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2763  }
2764 
2765  /// See AbstractAttribute::trackStatistics().
2766  void trackStatistics() const override {}
2767 
2768  void initialize(Attributor &A) override {
2769  Function *F = getAnchorScope();
2770  for (const auto &BB : *F)
2771  SingleThreadedBBs.insert(&BB);
2772  NumBBs = SingleThreadedBBs.size();
2773  }
2774 
2775  ChangeStatus manifest(Attributor &A) override {
2776  LLVM_DEBUG({
2777  for (const BasicBlock *BB : SingleThreadedBBs)
2778  dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2779  << BB->getName() << " is executed by a single thread.\n";
2780  });
2781  return ChangeStatus::UNCHANGED;
2782  }
2783 
2784  ChangeStatus updateImpl(Attributor &A) override;
2785 
2786  /// Check if an instruction is executed by a single thread.
2787  bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2788  return isExecutedByInitialThreadOnly(*I.getParent());
2789  }
2790 
2791  bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2792  return isValidState() && SingleThreadedBBs.contains(&BB);
2793  }
2794 
2795  /// Set of basic blocks that are executed by a single thread.
2796  SmallSetVector<const BasicBlock *, 16> SingleThreadedBBs;
2797 
2798  /// Total number of basic blocks in this function.
2799  long unsigned NumBBs = 0;
2800 };
2801 
2802 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2803  Function *F = getAnchorScope();
2804  ReversePostOrderTraversal<Function *> RPOT(F);
2805  auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2806 
2807  bool AllCallSitesKnown;
2808  auto PredForCallSite = [&](AbstractCallSite ACS) {
2809  const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2810  *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2811  DepClassTy::REQUIRED);
2812  return ACS.isDirectCall() &&
2813  ExecutionDomainAA.isExecutedByInitialThreadOnly(
2814  *ACS.getInstruction());
2815  };
2816 
2817  if (!A.checkForAllCallSites(PredForCallSite, *this,
2818  /* RequiresAllCallSites */ true,
2819  AllCallSitesKnown))
2820  SingleThreadedBBs.remove(&F->getEntryBlock());
2821 
2822  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2823  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2824 
2825  // Check if the edge into the successor block contains a condition that only
2826  // lets the main thread execute it.
2827  auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2828  if (!Edge || !Edge->isConditional())
2829  return false;
2830  if (Edge->getSuccessor(0) != SuccessorBB)
2831  return false;
2832 
2833  auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2834  if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2835  return false;
2836 
2837  ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2838  if (!C)
2839  return false;
2840 
2841  // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2842  if (C->isAllOnesValue()) {
2843  auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2844  CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2845  if (!CB)
2846  return false;
2847  const int InitModeArgNo = 1;
2848  auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
2849  return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
2850  }
2851 
2852  if (C->isZero()) {
2853  // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2854  if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2855  if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2856  return true;
2857 
2858  // Match: 0 == llvm.amdgcn.workitem.id.x()
2859  if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2860  if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2861  return true;
2862  }
2863 
2864  return false;
2865  };
2866 
2867  // Merge all the predecessor states into the current basic block. A basic
2868  // block is executed by a single thread if all of its predecessors are.
2869  auto MergePredecessorStates = [&](BasicBlock *BB) {
2870  if (pred_empty(BB))
2871  return SingleThreadedBBs.contains(BB);
2872 
2873  bool IsInitialThread = true;
2874  for (BasicBlock *PredBB : predecessors(BB)) {
2875  if (!IsInitialThreadOnly(dyn_cast<BranchInst>(PredBB->getTerminator()),
2876  BB))
2877  IsInitialThread &= SingleThreadedBBs.contains(PredBB);
2878  }
2879 
2880  return IsInitialThread;
2881  };
2882 
2883  for (auto *BB : RPOT) {
2884  if (!MergePredecessorStates(BB))
2885  SingleThreadedBBs.remove(BB);
2886  }
2887 
2888  return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2889  ? ChangeStatus::UNCHANGED
2890  : ChangeStatus::CHANGED;
2891 }
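// Illustrative sketch (simplified device IR): a guard such as
//
//   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
//   %is0 = icmp eq i32 %tid, 0
//   br i1 %is0, label %master, label %rest
//
// keeps %master in SingleThreadedBBs, because the edge into it only admits
// thread 0, whereas %rest merely inherits the state of its predecessors.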
2892 
2893 /// Try to replace memory allocation calls called by a single thread with a
2894 /// static buffer of shared memory.
2895 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2896  using Base = StateWrapper<BooleanState, AbstractAttribute>;
2897  AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2898 
2899  /// Create an abstract attribute view for the position \p IRP.
2900  static AAHeapToShared &createForPosition(const IRPosition &IRP,
2901  Attributor &A);
2902 
2903  /// Returns true if HeapToShared conversion is assumed to be possible.
2904  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2905 
2906  /// Returns true if HeapToShared conversion is assumed and the CB is a
2907  /// callsite to a free operation to be removed.
2908  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2909 
2910  /// See AbstractAttribute::getName().
2911  const std::string getName() const override { return "AAHeapToShared"; }
2912 
2913  /// See AbstractAttribute::getIdAddr().
2914  const char *getIdAddr() const override { return &ID; }
2915 
2916  /// This function should return true if the type of the \p AA is
2917  /// AAHeapToShared.
2918  static bool classof(const AbstractAttribute *AA) {
2919  return (AA->getIdAddr() == &ID);
2920  }
2921 
2922  /// Unique ID (due to the unique address)
2923  static const char ID;
2924 };
2925 
2926 struct AAHeapToSharedFunction : public AAHeapToShared {
2927  AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2928  : AAHeapToShared(IRP, A) {}
2929 
2930  const std::string getAsStr() const override {
2931  return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2932  " malloc calls eligible.";
2933  }
2934 
2935  /// See AbstractAttribute::trackStatistics().
2936  void trackStatistics() const override {}
2937 
2938  /// This functions finds free calls that will be removed by the
2939  /// HeapToShared transformation.
2940  void findPotentialRemovedFreeCalls(Attributor &A) {
2941  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2942  auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2943 
2944  PotentialRemovedFreeCalls.clear();
2945  // Update free call users of found malloc calls.
2946  for (CallBase *CB : MallocCalls) {
2947  SmallVector<CallBase *, 4> FreeCalls;
2948  for (auto *U : CB->users()) {
2949  CallBase *C = dyn_cast<CallBase>(U);
2950  if (C && C->getCalledFunction() == FreeRFI.Declaration)
2951  FreeCalls.push_back(C);
2952  }
2953 
2954  if (FreeCalls.size() != 1)
2955  continue;
2956 
2957  PotentialRemovedFreeCalls.insert(FreeCalls.front());
2958  }
2959  }
2960 
2961  void initialize(Attributor &A) override {
2962  if (DisableOpenMPOptDeglobalization) {
2963  indicatePessimisticFixpoint();
2964  return;
2965  }
2966 
2967  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2968  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2969 
2970  Attributor::SimplifictionCallbackTy SCB =
2971  [](const IRPosition &, const AbstractAttribute *,
2972  bool &) -> Optional<Value *> { return nullptr; };
2973  for (User *U : RFI.Declaration->users())
2974  if (CallBase *CB = dyn_cast<CallBase>(U)) {
2975  MallocCalls.insert(CB);
2976  A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
2977  SCB);
2978  }
2979 
2980  findPotentialRemovedFreeCalls(A);
2981  }
2982 
2983  bool isAssumedHeapToShared(CallBase &CB) const override {
2984  return isValidState() && MallocCalls.count(&CB);
2985  }
2986 
2987  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2988  return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2989  }
2990 
2991  ChangeStatus manifest(Attributor &A) override {
2992  if (MallocCalls.empty())
2993  return ChangeStatus::UNCHANGED;
2994 
2995  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2996  auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2997 
2998  Function *F = getAnchorScope();
2999  auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3000  DepClassTy::OPTIONAL);
3001 
3002  ChangeStatus Changed = ChangeStatus::UNCHANGED;
3003  for (CallBase *CB : MallocCalls) {
3004  // Skip replacing this if HeapToStack has already claimed it.
3005  if (HS && HS->isAssumedHeapToStack(*CB))
3006  continue;
3007 
3008  // Find the unique free call to remove it.
3009  SmallVector<CallBase *, 4> FreeCalls;
3010  for (auto *U : CB->users()) {
3011  CallBase *C = dyn_cast<CallBase>(U);
3012  if (C && C->getCalledFunction() == FreeCall.Declaration)
3013  FreeCalls.push_back(C);
3014  }
3015  if (FreeCalls.size() != 1)
3016  continue;
3017 
3018  auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3019 
3020  if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3021  LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3022  << " with shared memory."
3023  << " Shared memory usage is limited to "
3024  << SharedMemoryLimit << " bytes\n");
3025  continue;
3026  }
3027 
3028  LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3029  << " with " << AllocSize->getZExtValue()
3030  << " bytes of shared memory\n");
3031 
3032  // Create a new shared memory buffer of the same size as the allocation
3033  // and replace all the uses of the original allocation with it.
3034  Module *M = CB->getModule();
3035  Type *Int8Ty = Type::getInt8Ty(M->getContext());
3036  Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3037  auto *SharedMem = new GlobalVariable(
3038  *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3039  UndefValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3040  GlobalValue::NotThreadLocal,
3041  static_cast<unsigned>(AddressSpace::Shared));
3042  auto *NewBuffer =
3043  ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3044 
3045  auto Remark = [&](OptimizationRemark OR) {
3046  return OR << "Replaced globalized variable with "
3047  << ore::NV("SharedMemory", AllocSize->getZExtValue())
3048  << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
3049  << "of shared memory.";
3050  };
3051  A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3052 
3053  MaybeAlign Alignment = CB->getRetAlign();
3054  assert(Alignment &&
3055  "HeapToShared on allocation without alignment attribute");
3056  SharedMem->setAlignment(MaybeAlign(Alignment));
3057 
3058  A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3059  A.deleteAfterManifest(*CB);
3060  A.deleteAfterManifest(*FreeCalls.front());
3061 
3062  SharedMemoryUsed += AllocSize->getZExtValue();
3063  NumBytesMovedToSharedMemory = SharedMemoryUsed;
3064  Changed = ChangeStatus::CHANGED;
3065  }
3066 
3067  return Changed;
3068  }
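  // Illustrative sketch (simplified IR, hypothetical names): a globalization
  // call pair such as
  //
  //   %p = call i8* @__kmpc_alloc_shared(i64 4)
  //   ...
  //   call void @__kmpc_free_shared(i8* %p, i64 4)
  //
  // is rewritten to use a static buffer in the shared address space,
  //
  //   @p_shared = internal addrspace(3) global [4 x i8] undef
  //
  // with the alloc/free pair deleted, provided the allocation size is a known
  // constant and the call is executed by the initial thread only.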
3069 
3070  ChangeStatus updateImpl(Attributor &A) override {
3071  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3072  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3073  Function *F = getAnchorScope();
3074 
3075  auto NumMallocCalls = MallocCalls.size();
3076 
3077  // Only consider malloc calls executed by a single thread with a constant.
3078  for (User *U : RFI.Declaration->users()) {
3079  const auto &ED = A.getAAFor<AAExecutionDomain>(
3080  *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3081  if (CallBase *CB = dyn_cast<CallBase>(U))
3082  if (!isa<ConstantInt>(CB->getArgOperand(0)) ||
3083  !ED.isExecutedByInitialThreadOnly(*CB))
3084  MallocCalls.remove(CB);
3085  }
3086 
3087  findPotentialRemovedFreeCalls(A);
3088 
3089  if (NumMallocCalls != MallocCalls.size())
3090  return ChangeStatus::CHANGED;
3091 
3092  return ChangeStatus::UNCHANGED;
3093  }
3094 
3095  /// Collection of all malloc calls in a function.
3096  SmallSetVector<CallBase *, 4> MallocCalls;
3097  /// Collection of potentially removed free calls in a function.
3098  SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3099  /// The total amount of shared memory that has been used for HeapToShared.
3100  unsigned SharedMemoryUsed = 0;
3101 };
3102 
3103 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3104  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3105  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3106 
3107  /// Statistics are tracked as part of manifest for now.
3108  void trackStatistics() const override {}
3109 
3110  /// See AbstractAttribute::getAsStr()
3111  const std::string getAsStr() const override {
3112  if (!isValidState())
3113  return "<invalid>";
3114  return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3115  : "generic") +
3116  std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3117  : "") +
3118  std::string(" #PRs: ") +
3119  (ReachedKnownParallelRegions.isValidState()
3120  ? std::to_string(ReachedKnownParallelRegions.size())
3121  : "<invalid>") +
3122  ", #Unknown PRs: " +
3123  (ReachedUnknownParallelRegions.isValidState()
3124  ? std::to_string(ReachedUnknownParallelRegions.size())
3125  : "<invalid>") +
3126  ", #Reaching Kernels: " +
3127  (ReachingKernelEntries.isValidState()
3128  ? std::to_string(ReachingKernelEntries.size())
3129  : "<invalid>");
3130  }
3131 
3132  /// Create an abstract attribute view for the position \p IRP.
3133  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3134 
3135  /// See AbstractAttribute::getName()
3136  const std::string getName() const override { return "AAKernelInfo"; }
3137 
3138  /// See AbstractAttribute::getIdAddr()
3139  const char *getIdAddr() const override { return &ID; }
3140 
3141  /// This function should return true if the type of the \p AA is AAKernelInfo
3142  static bool classof(const AbstractAttribute *AA) {
3143  return (AA->getIdAddr() == &ID);
3144  }
3145 
3146  static const char ID;
3147 };
3148 
3149 /// The function kernel info abstract attribute, basically, what can we say
3150 /// about a function with regards to the KernelInfoState.
3151 struct AAKernelInfoFunction : AAKernelInfo {
3152  AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3153  : AAKernelInfo(IRP, A) {}
3154 
3155  SmallPtrSet<Instruction *, 4> GuardedInstructions;
3156 
3157  SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3158  return GuardedInstructions;
3159  }
3160 
3161  /// See AbstractAttribute::initialize(...).
3162  void initialize(Attributor &A) override {
3163  // This is a high-level transform that might change the constant arguments
3164  // of the init and deinit calls. We need to tell the Attributor about this
3165  // to avoid other parts using the current constant value for simplification.
3166  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3167 
3168  Function *Fn = getAnchorScope();
3169 
3170  OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3171  OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3172  OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3173  OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3174 
3175  // For kernels we perform more initialization work, first we find the init
3176  // and deinit calls.
3177  auto StoreCallBase = [](Use &U,
3178  OMPInformationCache::RuntimeFunctionInfo &RFI,
3179  CallBase *&Storage) {
3180  CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3181  assert(CB &&
3182  "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3183  assert(!Storage &&
3184  "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3185  Storage = CB;
3186  return false;
3187  };
3188  InitRFI.foreachUse(
3189  [&](Use &U, Function &) {
3190  StoreCallBase(U, InitRFI, KernelInitCB);
3191  return false;
3192  },
3193  Fn);
3194  DeinitRFI.foreachUse(
3195  [&](Use &U, Function &) {
3196  StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3197  return false;
3198  },
3199  Fn);
3200 
3201  // Ignore kernels without initializers such as global constructors.
3202  if (!KernelInitCB || !KernelDeinitCB)
3203  return;
3204 
3205  // Add itself to the reaching kernel and set IsKernelEntry.
3206  ReachingKernelEntries.insert(Fn);
3207  IsKernelEntry = true;
3208 
3209  // For kernels we might need to initialize/finalize the IsSPMD state and
3210  // we need to register a simplification callback so that the Attributor
3211  // knows the constant arguments to __kmpc_target_init and
3212  // __kmpc_target_deinit might actually change.
3213 
3214  Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
3215  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3216  bool &UsedAssumedInformation) -> Optional<Value *> {
3217  // IRP represents the "use generic state machine" argument of an
3218  // __kmpc_target_init call. We will answer this one with the internal
3219  // state. As long as we are not in an invalid state, we will create a
3220  // custom state machine so the value should be a `i1 false`. If we are
3221  // in an invalid state, we won't change the value that is in the IR.
3222  if (!ReachedKnownParallelRegions.isValidState())
3223  return nullptr;
3224  // If we have disabled state machine rewrites, don't make a custom one.
3225  if (DisableOpenMPOptStateMachineRewrite)
3226  return nullptr;
3227  if (AA)
3228  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3229  UsedAssumedInformation = !isAtFixpoint();
3230  auto *FalseVal =
3231  ConstantInt::getBool(IRP.getAnchorValue().getContext(), false);
3232  return FalseVal;
3233  };
3234 
3235  Attributor::SimplifictionCallbackTy ModeSimplifyCB =
3236  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3237  bool &UsedAssumedInformation) -> Optional<Value *> {
3238  // IRP represents the "Mode" argument of an __kmpc_target_init or
3239  // __kmpc_target_deinit call. We will answer this one with the internal
3240  // state of the SPMDCompatibilityTracker, i.e., the execution mode we
3241  // currently assume for the kernel.
3242  if (!SPMDCompatibilityTracker.isValidState())
3243  return nullptr;
3244  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3245  if (AA)
3246  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3247  UsedAssumedInformation = true;
3248  } else {
3249  UsedAssumedInformation = false;
3250  }
3251  auto *Val = ConstantInt::getSigned(
3252  IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()),
3253  SPMDCompatibilityTracker.isAssumed() ? OMP_TGT_EXEC_MODE_SPMD
3254  : OMP_TGT_EXEC_MODE_GENERIC);
3255  return Val;
3256  };
3257 
3258  Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
3259  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3260  bool &UsedAssumedInformation) -> Optional<Value *> {
3261  // IRP represents the "RequiresFullRuntime" argument of an
3262  // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
3263  // one with the internal state of the SPMDCompatibilityTracker, so if
3264  // generic then true, if SPMD then false.
3265  if (!SPMDCompatibilityTracker.isValidState())
3266  return nullptr;
3267  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3268  if (AA)
3269  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3270  UsedAssumedInformation = true;
3271  } else {
3272  UsedAssumedInformation = false;
3273  }
3274  auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
3275  !SPMDCompatibilityTracker.isAssumed());
3276  return Val;
3277  };
3278 
3279  constexpr const int InitModeArgNo = 1;
3280  constexpr const int DeinitModeArgNo = 1;
3281  constexpr const int InitUseStateMachineArgNo = 2;
3282  constexpr const int InitRequiresFullRuntimeArgNo = 3;
3283  constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
3284  A.registerSimplificationCallback(
3285  IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
3286  StateMachineSimplifyCB);
3287  A.registerSimplificationCallback(
3288  IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo),
3289  ModeSimplifyCB);
3290  A.registerSimplificationCallback(
3291  IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo),
3292  ModeSimplifyCB);
3293  A.registerSimplificationCallback(
3294  IRPosition::callsite_argument(*KernelInitCB,
3295  InitRequiresFullRuntimeArgNo),
3296  IsGenericModeSimplifyCB);
3297  A.registerSimplificationCallback(
3298  IRPosition::callsite_argument(*KernelDeinitCB,
3299  DeinitRequiresFullRuntimeArgNo),
3300  IsGenericModeSimplifyCB);
3301 
3302  // Check if we know we are in SPMD-mode already.
3303  ConstantInt *ModeArg =
3304  dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3305  if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3306  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3307  // This is a generic region but SPMDization is disabled so stop tracking.
3308  else if (DisableOpenMPOptSPMDization)
3309  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3310  }
3311 
3312  /// Sanitize the string \p S such that it is a suitable global symbol name.
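  /// For example, "x.y z" becomes "x.y.z": every character outside
  /// [A-Za-z0-9_] is replaced with '.'.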
3313  static std::string sanitizeForGlobalName(std::string S) {
3314  std::replace_if(
3315  S.begin(), S.end(),
3316  [](const char C) {
3317  return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3318  (C >= '0' && C <= '9') || C == '_');
3319  },
3320  '.');
3321  return S;
3322  }
3323 
3324  /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3325  /// finished now.
3326  ChangeStatus manifest(Attributor &A) override {
3327  // If we are not looking at a kernel with __kmpc_target_init and
3328  // __kmpc_target_deinit call we cannot actually manifest the information.
3329  if (!KernelInitCB || !KernelDeinitCB)
3330  return ChangeStatus::UNCHANGED;
3331 
3332  // If we can, we change the execution mode to SPMD mode; otherwise we
3333  // build a custom state machine.
3334  ChangeStatus Changed = ChangeStatus::UNCHANGED;
3335  if (!changeToSPMDMode(A, Changed))
3336  return buildCustomStateMachine(A);
3337 
3338  return Changed;
3339  }
3340 
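  /// Wrap each SPMD-incompatible region recorded in SPMDCompatibilityTracker
  /// in a guard so that only the main thread (hardware thread id 0) executes
  /// it. Values escaping a guarded region are broadcast to the other threads
  /// through shared memory, framed by simple SPMD barriers (see the CFG
  /// sketch below).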
3341  void insertInstructionGuardsHelper(Attributor &A) {
3342  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3343 
3344  auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3345  Instruction *RegionEndI) {
3346  LoopInfo *LI = nullptr;
3347  DominatorTree *DT = nullptr;
3348  MemorySSAUpdater *MSU = nullptr;
3349  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3350 
3351  BasicBlock *ParentBB = RegionStartI->getParent();
3352  Function *Fn = ParentBB->getParent();
3353  Module &M = *Fn->getParent();
3354 
3355  // Create all the blocks and logic.
3356  // ParentBB:
3357  // goto RegionCheckTidBB
3358  // RegionCheckTidBB:
3359  // Tid = __kmpc_hardware_thread_id()
3360  // if (Tid != 0)
3361  // goto RegionBarrierBB
3362  // RegionStartBB:
3363  // <execute instructions guarded>
3364  // goto RegionEndBB
3365  // RegionEndBB:
3366  // <store escaping values to shared mem>
3367  // goto RegionBarrierBB
3368  // RegionBarrierBB:
3369  // __kmpc_simple_barrier_spmd()
3370  // // second barrier is omitted if lacking escaping values.
3371  // <load escaping values from shared mem>
3372  // __kmpc_simple_barrier_spmd()
3373  // goto RegionExitBB
3374  // RegionExitBB:
3375  // <execute rest of instructions>
3376 
3377  BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3378  DT, LI, MSU, "region.guarded.end");
3379  BasicBlock *RegionBarrierBB =
3380  SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3381  MSU, "region.barrier");
3382  BasicBlock *RegionExitBB =
3383  SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3384  DT, LI, MSU, "region.exit");
3385  BasicBlock *RegionStartBB =
3386  SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3387 
3388  assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3389  "Expected a different CFG");
3390 
3391  BasicBlock *RegionCheckTidBB = SplitBlock(
3392  ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3393 
3394  // Register basic blocks with the Attributor.
3395  A.registerManifestAddedBasicBlock(*RegionEndBB);
3396  A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3397  A.registerManifestAddedBasicBlock(*RegionExitBB);
3398  A.registerManifestAddedBasicBlock(*RegionStartBB);
3399  A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3400 
3401  bool HasBroadcastValues = false;
3402  // Find escaping outputs from the guarded region to outside users and
3403  // broadcast their values to them.
3404  for (Instruction &I : *RegionStartBB) {
3405  SmallPtrSet<Instruction *, 4> OutsideUsers;
3406  for (User *Usr : I.users()) {
3407  Instruction &UsrI = *cast<Instruction>(Usr);
3408  if (UsrI.getParent() != RegionStartBB)
3409  OutsideUsers.insert(&UsrI);
3410  }
3411 
3412  if (OutsideUsers.empty())
3413  continue;
3414 
3415  HasBroadcastValues = true;
3416 
3417  // Emit a global variable in shared memory to store the broadcasted
3418  // value.
3419  auto *SharedMem = new GlobalVariable(
3420  M, I.getType(), /* IsConstant */ false,
3421  GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3422  sanitizeForGlobalName(
3423  (I.getName() + ".guarded.output.alloc").str()),
3424  nullptr, GlobalValue::NotThreadLocal,
3425  static_cast<unsigned>(AddressSpace::Shared));
3426 
3427  // Emit a store instruction to update the value.
3428  new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3429 
3430  LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3431  I.getName() + ".guarded.output.load",
3432  RegionBarrierBB->getTerminator());
3433 
3434  // Emit a load instruction and replace uses of the output value.
3435  for (Instruction *UsrI : OutsideUsers)
3436  UsrI->replaceUsesOfWith(&I, LoadI);
3437  }
3438 
3439  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3440 
3441  // Go to tid check BB in ParentBB.
3442  const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3443  ParentBB->getTerminator()->eraseFromParent();
3444  OpenMPIRBuilder::LocationDescription Loc(
3445  InsertPointTy(ParentBB, ParentBB->end()), DL);
3446  OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3447  uint32_t SrcLocStrSize;
3448  auto *SrcLocStr =
3449  OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3450  Value *Ident =
3451  OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3452  BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3453 
3454  // Add check for Tid in RegionCheckTidBB
3455  RegionCheckTidBB->getTerminator()->eraseFromParent();
3456  OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3457  InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3458  OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3459  FunctionCallee HardwareTidFn =
3460  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3461  M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3462  CallInst *Tid =
3463  OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3464  Tid->setDebugLoc(DL);
3465  OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
3466  Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3467  OMPInfoCache.OMPBuilder.Builder
3468  .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3469  ->setDebugLoc(DL);
3470 
3471  // First barrier for synchronization, ensures main thread has updated
3472  // values.
3473  FunctionCallee BarrierFn =
3474  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3475  M, OMPRTL___kmpc_barrier_simple_spmd);
3476  OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3477  RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3478  CallInst *Barrier =
3479  OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
3480  Barrier->setDebugLoc(DL);
3481  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3482 
3483  // Second barrier ensures workers have read broadcast values.
3484  if (HasBroadcastValues) {
3485  CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
3486  RegionBarrierBB->getTerminator());
3487  Barrier->setDebugLoc(DL);
3488  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3489  }
3490  };
3491 
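    // Before forming regions, sink guarded instructions that have no users
    // toward the next guarded instruction below them in the same block
    // (skipping __kmpc_alloc_shared calls). This packs guarded code into
    // fewer contiguous regions and therefore fewer guard branches and
    // barriers.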
3492  auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3493  SmallPtrSet<BasicBlock *, 8> Visited;
3494  for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3495  BasicBlock *BB = GuardedI->getParent();
3496  if (!Visited.insert(BB).second)
3497  continue;
3498 
3499  SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3500  Instruction *LastEffect = nullptr;
3501  BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3502  while (++IP != IPEnd) {
3503  if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3504  continue;
3505  Instruction *I = &*IP;
3506  if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3507  continue;
3508  if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3509  LastEffect = nullptr;
3510  continue;
3511  }
3512  if (LastEffect)
3513  Reorders.push_back({I, LastEffect});
3514  LastEffect = &*IP;
3515  }
3516  for (auto &Reorder : Reorders)
3517  Reorder.first->moveBefore(Reorder.second);
3518  }
3519 
3520  SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3521
3522  for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3523  BasicBlock *BB = GuardedI->getParent();
3524  auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3525  IRPosition::function(*GuardedI->getFunction()), nullptr,
3526  DepClassTy::NONE);
3527  assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3528  auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3529  // Continue if instruction is already guarded.
3530  if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3531  continue;
3532 
3533  Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3534  for (Instruction &I : *BB) {
3535  // If instruction I needs to be guarded update the guarded region
3536  // bounds.
3537  if (SPMDCompatibilityTracker.contains(&I)) {
3538  CalleeAAFunction.getGuardedInstructions().insert(&I);
3539  if (GuardedRegionStart)
3540  GuardedRegionEnd = &I;
3541  else
3542  GuardedRegionStart = GuardedRegionEnd = &I;
3543 
3544  continue;
3545  }
3546 
3547  // Instruction I does not need guarding, store
3548  // any region found and reset bounds.
3549  if (GuardedRegionStart) {
3550  GuardedRegions.push_back(
3551  std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3552  GuardedRegionStart = nullptr;
3553  GuardedRegionEnd = nullptr;
3554  }
3555  }
3556  }
3557 
3558  for (auto &GR : GuardedRegions)
3559  CreateGuardedRegion(GR.first, GR.second);
3560  }
3561 
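  /// Fallback used instead of instruction guards when the kernel cannot
  /// contain a parallel region: every thread except the main thread returns
  /// right after __kmpc_target_init, as sketched in the comment below.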
3562  void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
3563  // Only allow 1 thread per workgroup to continue executing the user code.
3564  //
3565  // InitCB = __kmpc_target_init(...)
3566  // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
3567  // if (ThreadIdInBlock != 0) return;
3568  // UserCode:
3569  // // user code
3570  //
3571  auto &Ctx = getAnchorValue().getContext();
3572  Function *Kernel = getAssociatedFunction();
3573  assert(Kernel && "Expected an associated function!");
3574 
3575  // Create block for user code to branch to from initial block.
3576  BasicBlock *InitBB = KernelInitCB->getParent();
3577  BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
3578  KernelInitCB->getNextNode(), "main.thread.user_code");
3579  BasicBlock *ReturnBB =
3580  BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
3581 
3582  // Register blocks with attributor:
3583  A.registerManifestAddedBasicBlock(*InitBB);
3584  A.registerManifestAddedBasicBlock(*UserCodeBB);
3585  A.registerManifestAddedBasicBlock(*ReturnBB);
3586 
3587  // Debug location:
3588  const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3589  ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
3590  InitBB->getTerminator()->eraseFromParent();
3591 
3592  // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
3593  Module &M = *Kernel->getParent();
3594  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3595  FunctionCallee ThreadIdInBlockFn =
3596  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3597  M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3598 
3599  // Get thread ID in block.
3600  CallInst *ThreadIdInBlock =
3601  CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
3602  OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
3603  ThreadIdInBlock->setDebugLoc(DLoc);
3604 
3605  // Eliminate all threads in the block with ID not equal to 0:
3606  Instruction *IsMainThread =
3607  ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
3608  ConstantInt::get(ThreadIdInBlock->getType(), 0),
3609  "thread.is_main", InitBB);
3610  IsMainThread->setDebugLoc(DLoc);
3611  BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
3612  }
3613 
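  /// Try to convert the kernel from generic mode to SPMD mode. This emits
  /// remarks and fails if unfixable SPMD-incompatible side effects remain;
  /// otherwise it guards the remaining incompatible instructions (or limits
  /// execution to the main thread), updates the <kernel>_exec_mode global,
  /// and rewrites the __kmpc_target_init/__kmpc_target_deinit arguments.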
3614  bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
3615  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3616 
3617  if (!SPMDCompatibilityTracker.isAssumed()) {
3618  for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3619  if (!NonCompatibleI)
3620  continue;
3621 
3622  // Skip diagnostics on calls to known OpenMP runtime functions for now.
3623  if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3624  if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3625  continue;
3626 
3627  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3628  ORA << "Value has potential side effects preventing SPMD-mode "
3629  "execution";
3630  if (isa<CallBase>(NonCompatibleI)) {
3631  ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3632  "the called function to override";
3633  }
3634  return ORA << ".";
3635  };
3636  A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3637  Remark);
3638 
3639  LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3640  << *NonCompatibleI << "\n");
3641  }
3642 
3643  return false;
3644  }
3645 
3646  // Get the actual kernel; it could be the caller of the anchor scope if we
3647  // have a debug wrapper.
3648  Function *Kernel = getAnchorScope();
3649  if (Kernel->hasLocalLinkage()) {
3650  assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
3651  auto *CB = cast<CallBase>(Kernel->user_back());
3652  Kernel = CB->getCaller();
3653  }
3654  assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!");
3655 
3656  // Check if the kernel is already in SPMD mode, if so, return success.
3657  GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3658  (Kernel->getName() + "_exec_mode").str());
3659  assert(ExecMode && "Kernel without exec mode?");
3660  assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
3661 
3662  // Set the global exec mode flag to indicate SPMD-Generic mode.
3663  assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
3664  "ExecMode is not an integer!");
3665  const int8_t ExecModeVal =
3666  cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
3667  if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
3668  return true;
3669 
3670  // We will now unconditionally modify the IR, indicate a change.
3671  Changed = ChangeStatus::CHANGED;
3672 
3673  // Do not use instruction guards when no parallel region is present inside
3674  // the target region.
3675  if (mayContainParallelRegion())
3676  insertInstructionGuardsHelper(A);
3677  else
3678  forceSingleThreadPerWorkgroupHelper(A);
3679 
3680  // Adjust the global exec mode flag that tells the runtime what mode this
3681  // kernel is executed in.
3682  assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
3683  "Initially non-SPMD kernel has SPMD exec mode!");
3684  ExecMode->setInitializer(
3685  ConstantInt::get(ExecMode->getInitializer()->getType(),
3686  ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
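    // Assuming the usual encoding where OMP_TGT_EXEC_MODE_GENERIC_SPMD is the
    // union of the GENERIC and SPMD bits, the OR above keeps the original
    // generic bit and marks the kernel as generic-SPMD rather than plain SPMD.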
3687 
3688  // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3689  const int InitModeArgNo = 1;
3690  const int DeinitModeArgNo = 1;
3691  const int InitUseStateMachineArgNo = 2;
3692  const int InitRequiresFullRuntimeArgNo = 3;
3693  const int DeinitRequiresFullRuntimeArgNo = 2;
3694 
3695  auto &Ctx = getAnchorValue().getContext();
3696  A.changeUseAfterManifest(
3697  KernelInitCB->getArgOperandUse(InitModeArgNo),
3698  *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3699  OMP_TGT_EXEC_MODE_SPMD));
3700  A.changeUseAfterManifest(
3701  KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3702  *ConstantInt::getBool(Ctx, false));
3703  A.changeUseAfterManifest(
3704  KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
3705  *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3706  OMP_TGT_EXEC_MODE_SPMD));
3707  A.changeUseAfterManifest(
3708  KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3709  *ConstantInt::getBool(Ctx, false));
3710  A.changeUseAfterManifest(
3711  KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3712  *ConstantInt::getBool(Ctx, false));
3713 
3714  ++NumOpenMPTargetRegionKernelsSPMD;
3715 
3716  auto Remark = [&](OptimizationRemark OR) {
3717  return OR << "Transformed generic-mode kernel to SPMD-mode.";
3718  };
3719  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3720  return true;
3721  };
3722 
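  /// Build a kernel-specific state machine for a generic-mode kernel: worker
  /// threads wait on a barrier, obtain the outlined parallel region through
  /// __kmpc_kernel_parallel, and dispatch to it via a direct if-cascade. An
  /// indirect-call fallback is emitted only if unknown parallel regions may
  /// be reached. See the CFG sketch below.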
3723  ChangeStatus buildCustomStateMachine(Attributor &A) {
3724  // If we have disabled state machine rewrites, don't make a custom one.
3725  if (DisableOpenMPOptStateMachineRewrite)
3726  return ChangeStatus::UNCHANGED;
3727 
3728  // Don't rewrite the state machine if we are not in a valid state.
3729  if (!ReachedKnownParallelRegions.isValidState())
3730  return ChangeStatus::UNCHANGED;
3731 
3732  const int InitModeArgNo = 1;
3733  const int InitUseStateMachineArgNo = 2;
3734 
3735  // Check that the current configuration is generic mode with the generic
3736  // state machine. If we already have SPMD mode or a custom state machine we
3737  // do not need to go any further. If either argument is anything but a
3738  // constant, something is weird and we give up.
3739  ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3740  KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3741  ConstantInt *Mode =
3742  dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3743 
3744  // If we are stuck with generic mode, try to create a custom device (=GPU)
3745  // state machine which is specialized for the parallel regions that are
3746  // reachable by the kernel.
3747  if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
3748  (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3749  return ChangeStatus::UNCHANGED;
3750 
3751  // If not SPMD mode, indicate we use a custom state machine now.
3752  auto &Ctx = getAnchorValue().getContext();
3753  auto *FalseVal = ConstantInt::getBool(Ctx, false);
3754  A.changeUseAfterManifest(
3755  KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3756 
3757  // If we don't actually need a state machine we are done here. This can
3758  // happen if there simply are no parallel regions. In the resulting kernel
3759  // all worker threads will simply exit right away, leaving the main thread
3760  // to do the work alone.
3761  if (!mayContainParallelRegion()) {
3762  ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3763 
3764  auto Remark = [&](OptimizationRemark OR) {
3765  return OR << "Removing unused state machine from generic-mode kernel.";
3766  };
3767  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3768 
3769  return ChangeStatus::CHANGED;
3770  }
3771 
3772  // Keep track in the statistics of our new shiny custom state machine.
3773  if (ReachedUnknownParallelRegions.empty()) {
3774  ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3775 
3776  auto Remark = [&](OptimizationRemark OR) {
3777  return OR << "Rewriting generic-mode kernel with a customized state "
3778  "machine.";
3779  };
3780  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3781  } else {
3782  ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3783 
3784  auto Remark = [&](OptimizationRemarkAnalysis OR) {
3785  return OR << "Generic-mode kernel is executed with a customized state "
3786  "machine that requires a fallback.";
3787  };
3788  A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3789 
3790  // Tell the user why we ended up with a fallback.
3791  for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3792  if (!UnknownParallelRegionCB)
3793  continue;
3794  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3795  return ORA << "Call may contain unknown parallel regions. Use "
3796  << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3797  "override.";
3798  };
3799  A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3800  "OMP133", Remark);
3801  }
3802  }
3803 
3804  // Create all the blocks:
3805  //
3806  // InitCB = __kmpc_target_init(...)
3807  // BlockHwSize =
3808  // __kmpc_get_hardware_num_threads_in_block();
3809  // WarpSize = __kmpc_get_warp_size();
3810  // BlockSize = BlockHwSize - WarpSize;
3811  // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
3812  // if (IsWorker) {
3813  // if (InitCB >= BlockSize) return;
3814  // SMBeginBB: __kmpc_barrier_simple_generic(...);
3815  // void *WorkFn;
3816  // bool Active = __kmpc_kernel_parallel(&WorkFn);
3817  // if (!WorkFn) return;
3818  // SMIsActiveCheckBB: if (Active) {
3819  // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
3820  // ParFn0(...);
3821  // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
3822  // ParFn1(...);
3823  // ...
3824  // SMIfCascadeCurrentBB: else
3825  // ((WorkFnTy*)WorkFn)(...);
3826  // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
3827  // }
3828  // SMDoneBB: __kmpc_barrier_simple_generic(...);
3829  // goto SMBeginBB;
3830  // }
3831  // UserCodeEntryBB: // user code
3832  // __kmpc_target_deinit(...)
3833  //
3834  Function *Kernel = getAssociatedFunction();
3835  assert(Kernel && "Expected an associated function!");
3836 
3837  BasicBlock *InitBB = KernelInitCB->getParent();
3838  BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3839  KernelInitCB->getNextNode(), "thread.user_code.check");
3840  BasicBlock *IsWorkerCheckBB =
3841  BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
3842  BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3843  Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3844  BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3845  Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3846  BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3847  Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3848  BasicBlock *StateMachineIfCascadeCurrentBB =
3849  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3850  Kernel, UserCodeEntryBB);
3851  BasicBlock *StateMachineEndParallelBB =
3852  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3853  Kernel, UserCodeEntryBB);
3854  BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3855  Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3856  A.registerManifestAddedBasicBlock(*InitBB);
3857  A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3858  A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
3859  A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3860  A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3861  A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3862  A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3863  A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3864  A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3865 
3866  const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3867  ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3868  InitBB->getTerminator()->eraseFromParent();
3869 
3870  Instruction *IsWorker =
3871  ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3872  ConstantInt::get(KernelInitCB->getType(), -1),
3873  "thread.is_worker", InitBB);
3874  IsWorker->setDebugLoc(DLoc);
3875  BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
3876 
3877  Module &M = *Kernel->getParent();
3878  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3879  FunctionCallee BlockHwSizeFn =
3880  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3881  M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
3882  FunctionCallee WarpSizeFn =
3883  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3884  M, OMPRTL___kmpc_get_warp_size);
3885  CallInst *BlockHwSize =
3886  CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
3887  OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
3888  BlockHwSize->setDebugLoc(DLoc);
3889  CallInst *WarpSize =
3890  CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
3891  OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
3892  WarpSize->setDebugLoc(DLoc);
3893  Instruction *BlockSize = BinaryOperator::CreateSub(
3894  BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
3895  BlockSize->setDebugLoc(DLoc);
3896  Instruction *IsMainOrWorker = ICmpInst::Create(
3897  ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
3898  "thread.is_main_or_worker", IsWorkerCheckBB);
3899  IsMainOrWorker->setDebugLoc(DLoc);
3900  BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
3901  IsMainOrWorker, IsWorkerCheckBB);
3902 
3903  // Create local storage for the work function pointer.
3904  const DataLayout &DL = M.getDataLayout();
3905  Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3906  Instruction *WorkFnAI =
3907  new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
3908  "worker.work_fn.addr", &Kernel->getEntryBlock().front());
3909  WorkFnAI->setDebugLoc(DLoc);
3910 
3911  OMPInfoCache.OMPBuilder.updateToLocation(
3912  OpenMPIRBuilder::LocationDescription(
3913  IRBuilder<>::InsertPoint(StateMachineBeginBB,
3914  StateMachineBeginBB->end()),
3915  DLoc));
3916 
3917  Value *Ident = KernelInitCB->getArgOperand(0);
3918  Value *GTid = KernelInitCB;
3919 
3920  FunctionCallee BarrierFn =
3921  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3922  M, OMPRTL___kmpc_barrier_simple_generic);
3923  CallInst *Barrier =
3924  CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
3925  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3926  Barrier->setDebugLoc(DLoc);
3927 
3928  if (WorkFnAI->getType()->getPointerAddressSpace() !=
3929  (unsigned int)AddressSpace::Generic) {
3930  WorkFnAI = new AddrSpaceCastInst(
3931  WorkFnAI,
3932  PointerType::getWithSamePointeeType(
3933  cast<PointerType>(WorkFnAI->getType()),
3934  (unsigned int)AddressSpace::Generic),
3935  WorkFnAI->getName() + ".generic", StateMachineBeginBB);
3936  WorkFnAI->setDebugLoc(DLoc);
3937  }
3938 
3939  FunctionCallee KernelParallelFn =
3940  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3941  M, OMPRTL___kmpc_kernel_parallel);
3942  CallInst *IsActiveWorker = CallInst::Create(
3943  KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3944  OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
3945  IsActiveWorker->setDebugLoc(DLoc);
3946  Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3947  StateMachineBeginBB);
3948  WorkFn->setDebugLoc(DLoc);
3949 
3950  FunctionType *ParallelRegionFnTy = FunctionType::get(
3951  Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3952  false);
3953  Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3954  WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3955  StateMachineBeginBB);
3956 
3957  Instruction *IsDone =
3958  ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3959  Constant::getNullValue(VoidPtrTy), "worker.is_done",
3960  StateMachineBeginBB);
3961  IsDone->setDebugLoc(DLoc);
3962  BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3963  IsDone, StateMachineBeginBB)
3964  ->setDebugLoc(DLoc);
3965 
3966  BranchInst::Create(StateMachineIfCascadeCurrentBB,
3967  StateMachineDoneBarrierBB, IsActiveWorker,
3968  StateMachineIsActiveCheckBB)
3969  ->setDebugLoc(DLoc);
3970 
3971  Value *ZeroArg =
3972  Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3973 
3974  // Now that we have most of the CFG skeleton it is time for the if-cascade
3975  // that checks the function pointer we got from the runtime against the
3976  // parallel regions we expect, if there are any.
3977  for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
3978  auto *ParallelRegion = ReachedKnownParallelRegions[I];
3979  BasicBlock *PRExecuteBB = BasicBlock::Create(
3980  Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3981  StateMachineEndParallelBB);
3982  CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3983  ->setDebugLoc(DLoc);
3984  BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3985  ->setDebugLoc(DLoc);
3986 
3987  BasicBlock *PRNextBB =
3988  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3989  Kernel, StateMachineEndParallelBB);
3990 
3991  // Check if we need to compare the pointer at all or if we can just
3992  // call the parallel region function.
3993  Value *IsPR;
3994  if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
3995  Instruction *CmpI = ICmpInst::Create(
3996  ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3997  "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3998  CmpI->setDebugLoc(DLoc);
3999  IsPR = CmpI;
4000  } else {
4001  IsPR = ConstantInt::getTrue(Ctx);
4002  }
4003 
4004  BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4005  StateMachineIfCascadeCurrentBB)
4006  ->setDebugLoc(DLoc);
4007  StateMachineIfCascadeCurrentBB = PRNextBB;
4008  }
4009 
4010  // At the end of the if-cascade we place the indirect function pointer call
4011  // in case we might need it, that is if there can be parallel regions we
4012  // have not handled in the if-cascade above.
4013  if (!ReachedUnknownParallelRegions.empty()) {
4014  StateMachineIfCascadeCurrentBB->setName(
4015  "worker_state_machine.parallel_region.fallback.execute");
4016  CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
4017  StateMachineIfCascadeCurrentBB)
4018  ->setDebugLoc(DLoc);
4019  }
4020  BranchInst::Create(StateMachineEndParallelBB,
4021  StateMachineIfCascadeCurrentBB)
4022  ->setDebugLoc(DLoc);
4023 
4024  FunctionCallee EndParallelFn =
4025  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4026  M, OMPRTL___kmpc_kernel_end_parallel);
4027  CallInst *EndParallel =
4028  CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4029  OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4030  EndParallel->setDebugLoc(DLoc);
4031  BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4032  ->setDebugLoc(DLoc);
4033 
4034  CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4035  ->setDebugLoc(DLoc);
4036  BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4037  ->setDebugLoc(DLoc);
4038 
4039  return ChangeStatus::CHANGED;
4040  }
4041 
4042  /// Fixpoint iteration update function. Will be called every time a dependence
4043  /// changes its state (and in the beginning).
4044  ChangeStatus updateImpl(Attributor &A) override {
4045  KernelInfoState StateBefore = getState();
4046 
4047  // Callback to check a read/write instruction.
4048  auto CheckRWInst = [&](Instruction &I) {
4049  // We handle calls later.
4050  if (isa<CallBase>(I))
4051  return true;
4052  // We only care about write effects.
4053  if (!I.mayWriteToMemory())
4054  return true;
4055  if (auto *SI = dyn_cast<StoreInst>(&I)) {
4056  SmallVector<const Value *> Objects;
4057  getUnderlyingObjects(SI->getPointerOperand(), Objects);
4058  if (llvm::all_of(Objects,
4059  [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
4060  return true;
4061  // Check for AAHeapToStack moved objects which must not be guarded.
4062  auto &HS = A.getAAFor<AAHeapToStack>(
4063  *this, IRPosition::function(*I.getFunction()),
4064  DepClassTy::OPTIONAL);
4065  if (llvm::all_of(Objects, [&HS](const Value *Obj) {
4066  auto *CB = dyn_cast<CallBase>(Obj);
4067  if (!CB)
4068  return false;
4069  return HS.isAssumedHeapToStack(*CB);
4070  })) {
4071  return true;
4072  }
4073  }
4074 
4075  // Insert instruction that needs guarding.
4076  SPMDCompatibilityTracker.insert(&I);
4077  return true;
4078  };
4079 
4080  bool UsedAssumedInformationInCheckRWInst = false;
4081  if (!SPMDCompatibilityTracker.isAtFixpoint())
4082  if (!A.checkForAllReadWriteInstructions(
4083  CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4084  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4085 
4086  bool UsedAssumedInformationFromReachingKernels = false;
4087  if (!IsKernelEntry) {
4088  updateParallelLevels(A);
4089 
4090  bool AllReachingKernelsKnown = true;
4091  updateReachingKernelEntries(A, AllReachingKernelsKnown);
4092  UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4093 
4094  if (!ParallelLevels.isValidState())
4095  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4096  else if (!ReachingKernelEntries.isValidState())
4097  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4098  else if (!SPMDCompatibilityTracker.empty()) {
4099  // Check if all reaching kernels agree on the mode, as we cannot guard
4100  // instructions otherwise. We might not be sure about the mode, so we
4101  // cannot fix the internal SPMD-ization state either.
4102  int SPMD = 0, Generic = 0;
4103  for (auto *Kernel : ReachingKernelEntries) {
4104  auto &CBAA = A.getAAFor<AAKernelInfo>(
4105  *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4106  if (CBAA.SPMDCompatibilityTracker.isValidState() &&
4107  CBAA.SPMDCompatibilityTracker.isAssumed())
4108  ++SPMD;
4109  else
4110  ++Generic;
4111  if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
4112  UsedAssumedInformationFromReachingKernels = true;
4113  }
4114  if (SPMD != 0 && Generic != 0)
4115  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4116  }
4117  }
4118 
4119  // Callback to check a call instruction.
4120  bool AllParallelRegionStatesWereFixed = true;
4121  bool AllSPMDStatesWereFixed = true;
4122  auto CheckCallInst = [&](Instruction &I) {
4123  auto &CB = cast<CallBase>(I);
4124  auto &CBAA = A.getAAFor<AAKernelInfo>(
4125  *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4126  getState() ^= CBAA.getState();
4127  AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
4128  AllParallelRegionStatesWereFixed &=
4129  CBAA.ReachedKnownParallelRegions.isAtFixpoint();
4130  AllParallelRegionStatesWereFixed &=
4131  CBAA.ReachedUnknownParallelRegions.isAtFixpoint();
4132  return true;
4133  };
4134 
4135  bool UsedAssumedInformationInCheckCallInst = false;
4136  if (!A.checkForAllCallLikeInstructions(
4137  CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4138  LLVM_DEBUG(dbgs() << TAG
4139  << "Failed to visit all call-like instructions!\n";);
4140  return indicatePessimisticFixpoint();
4141  }
4142 
4143  // If we haven't used any assumed information for the reached parallel
4144  // region states we can fix it.
4145  if (!UsedAssumedInformationInCheckCallInst &&
4146  AllParallelRegionStatesWereFixed) {
4147  ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4148  ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4149  }
4150 
4151  // If we haven't used any assumed information for the SPMD state we can fix
4152  // it.
4153  if (!UsedAssumedInformationInCheckRWInst &&
4154  !UsedAssumedInformationInCheckCallInst &&
4155  !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4156  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4157 
4158  return StateBefore == getState() ? ChangeStatus::UNCHANGED
4159  : ChangeStatus::CHANGED;
4160  }
4161 
4162 private:
4163  /// Update info regarding reaching kernels.
4164  void updateReachingKernelEntries(Attributor &A,
4165  bool &AllReachingKernelsKnown) {
4166  auto PredCallSite = [&](AbstractCallSite ACS) {
4167  Function *Caller = ACS.getInstruction()->getFunction();
4168 
4169  assert(Caller && "Caller is nullptr");
4170 
4171  auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
4172  IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4173  if (CAA.ReachingKernelEntries.isValidState()) {
4174  ReachingKernelEntries ^= CAA.ReachingKernelEntries;
4175  return true;
4176  }
4177 
4178  // We lost track of the caller of the associated function; any kernel
4179  // could reach it now.
4180  ReachingKernelEntries.indicatePessimisticFixpoint();
4181 
4182  return true;
4183  };
4184 
4185  if (!A.checkForAllCallSites(PredCallSite, *this,
4186  true /* RequireAllCallSites */,
4187  AllReachingKernelsKnown))
4188  ReachingKernelEntries.indicatePessimisticFixpoint();
4189  }
4190 
4191  /// Update info regarding parallel levels.
4192  void updateParallelLevels(Attributor &A) {
4193  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4194  OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4195  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4196 
4197  auto PredCallSite = [&](AbstractCallSite ACS) {
4198  Function *Caller = ACS.getInstruction()->getFunction();
4199 
4200  assert(Caller && "Caller is nullptr");
4201 
4202  auto &CAA =
4203  A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4204  if (CAA.ParallelLevels.isValidState()) {
4205  // Any function that is called by `__kmpc_parallel_51` will not be
4206  // folded, as the parallel level in the function is updated. Getting this
4207  // right would make the analysis depend on the implementation, and any
4208  // future change to the implementation could silently break it. As a
4209  // consequence, we are just conservative here.
4210  if (Caller == Parallel51RFI.Declaration) {
4211  ParallelLevels.indicatePessimisticFixpoint();
4212  return true;
4213  }
4214 
4215  ParallelLevels ^= CAA.ParallelLevels;
4216 
4217  return true;
4218  }
4219 
4220  // We lost track of the caller of the associated function; any kernel
4221  // could reach it now.
4222  ParallelLevels.indicatePessimisticFixpoint();
4223 
4224  return true;
4225  };
4226 
4227  bool AllCallSitesKnown = true;
4228  if (!A.checkForAllCallSites(PredCallSite, *this,
4229  true /* RequireAllCallSites */,
4230  AllCallSitesKnown))
4231  ParallelLevels.indicatePessimisticFixpoint();
4232  }
4233 };
4234 
4235 /// The call site kernel info abstract attribute, basically, what can we say
4236 /// about a call site with regards to the KernelInfoState. For now this simply
4237 /// forwards the information from the callee.
4238 struct AAKernelInfoCallSite : AAKernelInfo {
4239  AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4240  : AAKernelInfo(IRP, A) {}
4241 
4242  /// See AbstractAttribute::initialize(...).
4243  void initialize(Attributor &A) override {
4244  AAKernelInfo::initialize(A);
4245
4246  CallBase &CB = cast<CallBase>(getAssociatedValue());
4247  Function *Callee = getAssociatedFunction();
4248 
4249  auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4250  *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4251
4252  // Check for SPMD-mode assumptions.
4253  if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) {
4254  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4255  indicateOptimisticFixpoint();
4256  }
4257 
4258  // First weed out calls we do not care about, that is readonly/readnone
4259  // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4260  // parallel region or anything else we are looking for.
4261  if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4262  indicateOptimisticFixpoint();
4263  return;
4264  }
4265 
4266  // Next we check if we know the callee. If it is a known OpenMP function
4267  // we will handle them explicitly in the switch below. If it is not, we
4268  // will use an AAKernelInfo object on the callee to gather information and
4269  // merge that into the current state. The latter happens in the updateImpl.
4270  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4271  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4272  if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4273  // Unknown callees or declarations are not analyzable; we give up.
4274  if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4275 
4276  // Unknown callees might contain parallel regions, except if they have
4277  // an appropriate assumption attached.
4278  if (!(AssumptionAA.hasAssumption("omp_no_openmp") ||
4279  AssumptionAA.hasAssumption("omp_no_parallelism")))
4280  ReachedUnknownParallelRegions.insert(&CB);
4281 
4282  // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4283  // idea we can run something unknown in SPMD-mode.
4284  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4285  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4286  SPMDCompatibilityTracker.insert(&CB);
4287  }
4288 
4289  // We have updated the state for this unknown call properly; there won't
4290  // be any change, so we indicate a fixpoint.
4291  indicateOptimisticFixpoint();
4292  }
4293  // If the callee is known and can be used in IPO, we will update the state
4294  // based on the callee state in updateImpl.
4295  return;
4296  }
4297 
4298  const unsigned int WrapperFunctionArgNo = 6;
4299  RuntimeFunction RF = It->getSecond();
4300  switch (RF) {
4301  // All the functions we know are compatible with SPMD mode.
4302  case OMPRTL___kmpc_is_spmd_exec_mode:
4303  case OMPRTL___kmpc_distribute_static_fini:
4304  case OMPRTL___kmpc_for_static_fini:
4305  case OMPRTL___kmpc_global_thread_num:
4306  case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4307  case OMPRTL___kmpc_get_hardware_num_blocks:
4308  case OMPRTL___kmpc_single:
4309  case OMPRTL___kmpc_end_single:
4310  case OMPRTL___kmpc_master:
4311  case OMPRTL___kmpc_end_master:
4312  case OMPRTL___kmpc_barrier:
4313  case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4314  case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4315  case OMPRTL___kmpc_nvptx_end_reduce_nowait:
4316  break;
4317  case OMPRTL___kmpc_distribute_static_init_4:
4318  case OMPRTL___kmpc_distribute_static_init_4u:
4319  case OMPRTL___kmpc_distribute_static_init_8:
4320  case OMPRTL___kmpc_distribute_static_init_8u:
4321  case OMPRTL___kmpc_for_static_init_4:
4322  case OMPRTL___kmpc_for_static_init_4u:
4323  case OMPRTL___kmpc_for_static_init_8:
4324  case OMPRTL___kmpc_for_static_init_8u: {
4325  // Check the schedule and allow static schedule in SPMD mode.
4326  unsigned ScheduleArgOpNo = 2;
4327  auto *ScheduleTypeCI =
4328  dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4329  unsigned ScheduleTypeVal =
4330  ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4331  switch (OMPScheduleType(ScheduleTypeVal)) {
4332  case OMPScheduleType::UnorderedStatic:
4333  case OMPScheduleType::UnorderedStaticChunked:
4334  case OMPScheduleType::OrderedDistribute:
4335  case OMPScheduleType::OrderedDistributeChunked:
4336  break;
4337  default:
4338  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4339  SPMDCompatibilityTracker.insert(&CB);
4340  break;
4341  };
4342  } break;
4343  case OMPRTL___kmpc_target_init:
4344  KernelInitCB = &CB;
4345  break;
4346  case OMPRTL___kmpc_target_deinit:
4347  KernelDeinitCB = &CB;
4348  break;
4349  case OMPRTL___kmpc_parallel_51:
4350  if (auto *ParallelRegion = dyn_cast<Function>(
4351  CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
4352  ReachedKnownParallelRegions.insert(ParallelRegion);
4353  break;
4354  }
4355  // The condition above should usually get the parallel region function
4356  // pointer and record it. On the off chance it doesn't, we assume the
4357  // worst.
4358  ReachedUnknownParallelRegions.insert(&CB);
4359  break;
4360  case OMPRTL___kmpc_omp_task:
4361  // We do not look into tasks right now, just give up.
4362  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4363  SPMDCompatibilityTracker.insert(&CB);
4364  ReachedUnknownParallelRegions.insert(&CB);
4365  break;
4366  case OMPRTL___kmpc_alloc_shared:
4367  case OMPRTL___kmpc_free_shared:
4368  // Return without setting a fixpoint, to be resolved in updateImpl.
4369  return;
4370  default:
4371  // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
4372  // generally. However, they do not hide parallel regions.
4373  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4374  SPMDCompatibilityTracker.insert(&CB);
4375  break;
4376  }
4377  // All other OpenMP runtime calls will not reach parallel regions so they
4378  // can be safely ignored for now. Since it is a known OpenMP runtime call we
4379  // have now modeled all effects and there is no need for any update.
4380  indicateOptimisticFixpoint();
4381  }
4382 
4383  ChangeStatus updateImpl(Attributor &A) override {
4384  // TODO: Once we have call site specific value information we can provide
4385  // call site specific liveness information and then it makes
4386  // sense to specialize attributes for call sites arguments instead of
4387  // redirecting requests to the callee argument.
4388  Function *F = getAssociatedFunction();
4389 
4390  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4391  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
4392 
4393  // If F is not a runtime function, propagate the AAKernelInfo of the callee.
4394  if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4395  const IRPosition &FnPos = IRPosition::function(*F);
4396  auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
4397  if (getState() == FnAA.getState())
4398  return ChangeStatus::UNCHANGED;
4399  getState() = FnAA.getState();
4400  return ChangeStatus::CHANGED;
4401  }
4402 
4403  // F is a runtime function that allocates or frees memory, check
4404  // AAHeapToStack and AAHeapToShared.
4405  KernelInfoState StateBefore = getState();
4406  assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
4407  It->getSecond() == OMPRTL___kmpc_free_shared) &&
4408  "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
4409 
4410  CallBase &CB = cast<CallBase>(getAssociatedValue());
4411 
4412  auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
4413  *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4414  auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
4415  *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4416
4417  RuntimeFunction RF = It->getSecond();
4418 
4419  switch (RF) {
4420  // If neither HeapToStack nor HeapToShared assume the call is removed,
4421  // assume SPMD incompatibility.
4422  case OMPRTL___kmpc_alloc_shared:
4423  if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
4424  !HeapToSharedAA.isAssumedHeapToShared(CB))
4425  SPMDCompatibilityTracker.insert(&CB);
4426  break;
4427  case OMPRTL___kmpc_free_shared:
4428  if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
4429  !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
4430  SPMDCompatibilityTracker.insert(&CB);
4431  break;
4432  default:
4433  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4434  SPMDCompatibilityTracker.insert(&CB);
4435  }
4436 
4437  return StateBefore == getState() ? ChangeStatus::UNCHANGED
4438  : ChangeStatus::CHANGED;
4439  }
4440 };
4441 
4442 struct AAFoldRuntimeCall
4443  : public StateWrapper<BooleanState, AbstractAttribute> {
4444  using Base = StateWrapper<BooleanState, AbstractAttribute>;
4445
4446  AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
4447 
4448  /// Statistics are tracked as part of manifest for now.
4449  void trackStatistics() const override {}
4450 
4451  /// Create an abstract attribute view for the position \p IRP.
4452  static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
4453  Attributor &A);
4454 
4455  /// See AbstractAttribute::getName()
4456  const std::string getName() const override { return "AAFoldRuntimeCall"; }
4457 
4458  /// See AbstractAttribute::getIdAddr()
4459  const char *getIdAddr() const override { return &ID; }
4460 
4461  /// This function should return true if the type of the \p AA is
4462  /// AAFoldRuntimeCall
4463  static bool classof(const AbstractAttribute *AA) {
4464  return (AA->getIdAddr() == &ID);
4465  }
4466 
4467  static const char ID;
4468 };
4469 
4470 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
4471  AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
4472  : AAFoldRuntimeCall(IRP, A) {}
4473 
4474  /// See AbstractAttribute::getAsStr()
4475  const std::string getAsStr() const override {
4476  if (!isValidState())
4477  return "<invalid>";
4478 
4479  std::string Str("simplified value: ");
4480 
4481  if (!SimplifiedValue)
4482  return Str + std::string("none");
4483 
4484  if (!SimplifiedValue.value())
4485  return Str + std::string("nullptr");
4486 
4487  if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.value()))
4488  return Str + std::to_string(CI->getSExtValue());
4489 
4490  return Str + std::string("unknown");
4491  }
4492 
4493  void initialize(Attributor &A) override {
4494  if (DisableOpenMPOptFolding)
4495  indicatePessimisticFixpoint();
4496 
4497  Function *Callee = getAssociatedFunction();
4498 
4499  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4500  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4501  assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4502  "Expected a known OpenMP runtime function");
4503 
4504  RFKind = It->getSecond();
4505 
4506  CallBase &CB = cast<CallBase>(getAssociatedValue());
4507  A.registerSimplificationCallback(
4508  IRPosition::callsite_returned(CB),
4509  [&](const IRPosition &IRP, const AbstractAttribute *AA,
4510  bool &UsedAssumedInformation) -> Optional<Value *> {
4511  assert((isValidState() ||
4512  (SimplifiedValue && SimplifiedValue.value() == nullptr)) &&
4513  "Unexpected invalid state!");
4514 
4515  if (!isAtFixpoint()) {
4516  UsedAssumedInformation = true;
4517  if (AA)
4518  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4519  }
4520  return SimplifiedValue;
4521  });
4522  }
4523 
4524  ChangeStatus updateImpl(Attributor &A) override {
4525  ChangeStatus Changed = ChangeStatus::UNCHANGED;
4526  switch (RFKind) {
4527  case OMPRTL___kmpc_is_spmd_exec_mode:
4528  Changed |= foldIsSPMDExecMode(A);
4529  break;
4530  case OMPRTL___kmpc_is_generic_main_thread_id:
4531  Changed |= foldIsGenericMainThread(A);
4532  break;
4533  case OMPRTL___kmpc_parallel_level:
4534  Changed |= foldParallelLevel(A);
4535  break;
4536  case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4537  Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
4538  break;
4539  case OMPRTL___kmpc_get_hardware_num_blocks:
4540  Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
4541  break;
4542  default:
4543  llvm_unreachable("Unhandled OpenMP runtime function!");
4544  }
4545 
4546  return Changed;
4547  }
4548 
4549  ChangeStatus manifest(Attributor &A) override {
4550  ChangeStatus Changed = ChangeStatus::UNCHANGED;
4551
4552  if (SimplifiedValue && *SimplifiedValue) {
4553  Instruction &I = *getCtxI();
4554  A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
4555  A.deleteAfterManifest(I);
4556 
4557  CallBase *CB = dyn_cast<CallBase>(&I);
4558  auto Remark = [&](OptimizationRemark OR) {
4559  if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4560  return OR << "Replacing OpenMP runtime call "
4561  << CB->getCalledFunction()->getName() << " with "
4562  << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4563  return OR << "Replacing OpenMP runtime call "
4564  << CB->getCalledFunction()->getName() << ".";
4565  };
4566 
4567  if (CB && EnableVerboseRemarks)
4568  A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4569 
4570  LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4571  << **SimplifiedValue << "\n");
4572 
4573  Changed = ChangeStatus::CHANGED;
4574  }
4575 
4576  return Changed;
4577  }
4578 
4579  ChangeStatus indicatePessimisticFixpoint() override {
4580  SimplifiedValue = nullptr;
4581  return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4582  }
4583 
4584 private:
4585  /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4586  ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4587  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4588 
4589  unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4590  unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4591  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4592  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4593 
4594  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4595  return indicatePessimisticFixpoint();
4596 
4597  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4598  auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4599  DepClassTy::REQUIRED);
4600
4601  if (!AA.isValidState()) {
4602  SimplifiedValue = nullptr;
4603  return indicatePessimisticFixpoint();
4604  }
4605 
4606  if (AA.SPMDCompatibilityTracker.isAssumed()) {
4607  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4608  ++KnownSPMDCount;
4609  else
4610  ++AssumedSPMDCount;
4611  } else {
4612  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4613  ++KnownNonSPMDCount;
4614  else
4615  ++AssumedNonSPMDCount;
4616  }
4617  }
4618 
4619  if ((AssumedSPMDCount + KnownSPMDCount) &&
4620  (AssumedNonSPMDCount + KnownNonSPMDCount))
4621  return indicatePessimisticFixpoint();
4622 
4623  auto &Ctx = getAnchorValue().getContext();
4624  if (KnownSPMDCount || AssumedSPMDCount) {
4625  assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4626  "Expected only SPMD kernels!");
4627  // All reaching kernels are in SPMD mode. Update all function calls to
4628  // __kmpc_is_spmd_exec_mode to 1.
4629  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4630  } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4631  assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4632  "Expected only non-SPMD kernels!");
4633  // All reaching kernels are in non-SPMD mode. Update all function
4634  // calls to __kmpc_is_spmd_exec_mode to 0.
4635  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4636  } else {
4637  // There are no reaching kernels, therefore we cannot tell if the
4638  // associated call site can be folded. At this moment, SimplifiedValue
4639  // must be none.
4640  assert(!SimplifiedValue && "SimplifiedValue should be none");
4641  }
4642 
4643  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4644  : ChangeStatus::CHANGED;
4645  }
4646 
4647  /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4648  ChangeStatus foldIsGenericMainThread(Attributor &A) {
4649  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4650 
4651  CallBase &CB = cast<CallBase>(getAssociatedValue());
4652  Function *F = CB.getFunction();
4653  const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4654  *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4655
4656  if (!ExecutionDomainAA.isValidState())
4657  return indicatePessimisticFixpoint();
4658 
4659  auto &Ctx = getAnchorValue().getContext();
4660  if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4661  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4662  else
4663  return indicatePessimisticFixpoint();
4664 
4665  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4666  : ChangeStatus::CHANGED;
4667  }
4668 
4669  /// Fold __kmpc_parallel_level into a constant if possible.
4670  ChangeStatus foldParallelLevel(Attributor &A) {
4671  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4672 
4673  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4674  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4675 
4676  if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4677  return indicatePessimisticFixpoint();
4678 
4679  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4680  return indicatePessimisticFixpoint();
4681 
4682  if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4683  assert(!SimplifiedValue &&
4684  "SimplifiedValue should keep none at this point");
4685  return ChangeStatus::UNCHANGED;
4686  }
4687 
4688  unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4689  unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4690  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
 4691  auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
 4692  DepClassTy::REQUIRED);
 4693  if (!AA.SPMDCompatibilityTracker.isValidState())
4694  return indicatePessimisticFixpoint();
4695 
4696  if (AA.SPMDCompatibilityTracker.isAssumed()) {
4697  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4698  ++KnownSPMDCount;
4699  else
4700  ++AssumedSPMDCount;
4701  } else {
4702  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4703  ++KnownNonSPMDCount;
4704  else
4705  ++AssumedNonSPMDCount;
4706  }
4707  }
4708 
4709  if ((AssumedSPMDCount + KnownSPMDCount) &&
4710  (AssumedNonSPMDCount + KnownNonSPMDCount))
4711  return indicatePessimisticFixpoint();
4712 
4713  auto &Ctx = getAnchorValue().getContext();
4714  // If the caller can only be reached by SPMD kernel entries, the parallel
4715  // level is 1. Similarly, if the caller can only be reached by non-SPMD
4716  // kernel entries, it is 0.
4717  if (AssumedSPMDCount || KnownSPMDCount) {
4718  assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4719  "Expected only SPMD kernels!");
4720  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4721  } else {
4722  assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4723  "Expected only non-SPMD kernels!");
4724  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4725  }
 4726  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
 4727  : ChangeStatus::CHANGED;
 4728  }
4729 
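The rule encoded above reflects a property of the two device execution schemes: an SPMD kernel runs its body inside one implicit parallel region, while a generic-mode kernel starts outside any parallel region. A hedged user-level illustration (ordinary OpenMP offload code, not part of this pass; whether the region is actually lowered as SPMD is up to the compiler):

#include <omp.h>

int queryLevelOnDevice() {
  int Level = -1;
  // A combined "target parallel" construct is a typical SPMD-style region:
  // every device thread is already inside one parallel region here.
  #pragma omp target parallel map(from : Level)
  {
    if (omp_get_thread_num() == 0)
      Level = omp_get_level();   // typically 1 in an SPMD-lowered kernel
  }
  return Level;
}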
4730  ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
4731  // Specialize only if all the calls agree with the attribute constant value
4732  int32_t CurrentAttrValue = -1;
4733  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4734 
4735  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4736  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4737 
4738  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4739  return indicatePessimisticFixpoint();
4740 
4741  // Iterate over the kernels that reach this function
4742  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4743  int32_t NextAttrVal = -1;
4744  if (K->hasFnAttribute(Attr))
4745  NextAttrVal =
4746  std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4747 
4748  if (NextAttrVal == -1 ||
4749  (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4750  return indicatePessimisticFixpoint();
4751  CurrentAttrValue = NextAttrVal;
4752  }
4753 
4754  if (CurrentAttrValue != -1) {
4755  auto &Ctx = getAnchorValue().getContext();
4756  SimplifiedValue =
4757  ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4758  }
 4759  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
 4760  : ChangeStatus::CHANGED;
 4761  }
4762 
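The same agreement idea, written as a standalone helper: a call can only be folded to a kernel attribute value when every reaching kernel carries that attribute and all of them agree on it. Attribute names such as "omp_target_thread_limit" are an assumption about what the frontend attaches; the helper is a sketch, not part of the pass, and uses StringRef::getAsInteger instead of std::stoi so a malformed value simply blocks the fold:

#include <cstdint>
#include <optional>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Function.h"

static std::optional<int32_t> commonKernelAttr(llvm::ArrayRef<llvm::Function *> Kernels,
                                               llvm::StringRef Attr) {
  std::optional<int32_t> Common;
  for (llvm::Function *K : Kernels) {
    if (!K->hasFnAttribute(Attr))
      return std::nullopt;   // one kernel lacks the attribute
    int32_t Val;
    if (K->getFnAttribute(Attr).getValueAsString().getAsInteger(10, Val))
      return std::nullopt;   // attribute value is not a plain integer
    if (Common && *Common != Val)
      return std::nullopt;   // kernels disagree
    Common = Val;
  }
  return Common;             // nullopt also covers the "no reaching kernels" case
}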
4763  /// An optional value the associated value is assumed to fold to. That is, we
4764  /// assume the associated value (which is a call) can be replaced by this
4765  /// simplified value.
4766  Optional<Value *> SimplifiedValue;
4767 
4768  /// The runtime function kind of the callee of the associated call site.
4769  RuntimeFunction RFKind;
4770 };
4771 
4772 } // namespace
4773 
4774 /// Register folding callsite
4775 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4776  auto &RFI = OMPInfoCache.RFIs[RF];
4777  RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4778  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4779  if (!CI)
4780  return false;
4781  A.getOrCreateAAFor<AAFoldRuntimeCall>(
4782  IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4783  DepClassTy::NONE, /* ForceUpdate */ false,
4784  /* UpdateAfterInit */ false);
4785  return false;
4786  });
4787 }
4788 
4789 void OpenMPOpt::registerAAs(bool IsModulePass) {
4790  if (SCC.empty())
4791  return;
4792 
4793  if (IsModulePass) {
4794  // Ensure we create the AAKernelInfo AAs first and without triggering an
4795  // update. This will make sure we register all value simplification
4796  // callbacks before any other AA has the chance to create an AAValueSimplify
4797  // or similar.
4798  auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
4799  A.getOrCreateAAFor<AAKernelInfo>(
4800  IRPosition::function(Kernel), /* QueryingAA */ nullptr,
4801  DepClassTy::NONE, /* ForceUpdate */ false,
4802  /* UpdateAfterInit */ false);
4803  return false;
4804  };
4805  OMPInformationCache::RuntimeFunctionInfo &InitRFI =
4806  OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
4807  InitRFI.foreachUse(SCC, CreateKernelInfoCB);
4808 
4809  registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4810  registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4811  registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4812  registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4813  registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4814  }
4815 
4816  // Create CallSite AA for all Getters.
4817  for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4818  auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4819 
4820  auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4821 
4822  auto CreateAA = [&](Use &U, Function &Caller) {
4823  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4824  if (!CI)
4825  return false;
4826 
4827  auto &CB = cast<CallBase>(*CI);
4828 
 4829  IRPosition CBPos = IRPosition::callsite_function(CB);
 4830  A.getOrCreateAAFor<AAICVTracker>(CBPos);
4831  return false;
4832  };
4833 
4834  GetterRFI.foreachUse(SCC, CreateAA);
4835  }
4836  auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4837  auto CreateAA = [&](Use &U, Function &F) {
4838  A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4839  return false;
4840  };
 4841  if (!DisableOpenMPOptDeglobalization)
 4842  GlobalizationRFI.foreachUse(SCC, CreateAA);
4843 
4844  // Create an ExecutionDomain AA for every function and a HeapToStack AA for
4845  // every function if there is a device kernel.
4846  if (!isOpenMPDevice(M))
4847  return;
4848 
4849  for (auto *F : SCC) {
4850  if (F->isDeclaration())
4851  continue;
4852 
4853  A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
 4854  if (!DisableOpenMPOptDeglobalization)
 4855  A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4856 
4857  for (auto &I : instructions(*F)) {
4858  if (auto *LI = dyn_cast<LoadInst>(&I)) {
4859  bool UsedAssumedInformation = false;
4860  A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4861  UsedAssumedInformation, AA::Interprocedural);
4862  } else if (auto *SI = dyn_cast<StoreInst>(&I)) {
4863  A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
4864  }
4865  }
4866  }
4867 }
4868 
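All of the registrations above go through the same foreachUse(...) idiom. A hedged sketch of the callback contract they rely on, assuming a RuntimeFunctionInfo RFI and the pass's SCC vector are in scope as they are inside OpenMPOpt; returning true would tell the cache that the use was rewritten and can be dropped, which is why these registration callbacks always return false (they only create abstract attributes):

unsigned NumRegularCalls = 0;
auto CountCalls = [&](Use &U, Function &Caller) {
  // Only count direct, well-formed calls to the runtime function.
  if (OpenMPOpt::getCallIfRegularCall(U, &RFI))
    ++NumRegularCalls;   // inspect only; the IR is left untouched
  return false;          // false: keep the use in the cached use list
};
RFI.foreachUse(SCC, CountCalls);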
4869 const char AAICVTracker::ID = 0;
4870 const char AAKernelInfo::ID = 0;
4871 const char AAExecutionDomain::ID = 0;
4872 const char AAHeapToShared::ID = 0;
4873 const char AAFoldRuntimeCall::ID = 0;
4874 
4875 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4876  Attributor &A) {
4877  AAICVTracker *AA = nullptr;
4878  switch (IRP.getPositionKind()) {
 4879  case IRPosition::IRP_INVALID:
 4880  case IRPosition::IRP_FLOAT:
 4881  case IRPosition::IRP_ARGUMENT:
 4882  case IRPosition::IRP_CALL_SITE_ARGUMENT:
 4883  llvm_unreachable("ICVTracker can only be created for function position!");
 4884  case IRPosition::IRP_RETURNED:
 4885  AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
 4886  break;
 4887  case IRPosition::IRP_CALL_SITE_RETURNED:
 4888  AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
 4889  break;
 4890  case IRPosition::IRP_CALL_SITE:
 4891  AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
 4892  break;
 4893  case IRPosition::IRP_FUNCTION:
 4894  AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
 4895  break;
4896  }
4897 
4898  return *AA;
4899 }
4900 
 4901 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
 4902  Attributor &A) {
4903  AAExecutionDomainFunction *AA = nullptr;
4904  switch (IRP.getPositionKind()) {
 4905  case IRPosition::IRP_INVALID:
 4906  case IRPosition::IRP_FLOAT:
 4907  case IRPosition::IRP_ARGUMENT:
 4908  case IRPosition::IRP_CALL_SITE_ARGUMENT:
 4909  case IRPosition::IRP_RETURNED:
 4910  case IRPosition::IRP_CALL_SITE_RETURNED:
 4911  case IRPosition::IRP_CALL_SITE:
 4912  llvm_unreachable(
 4913  "AAExecutionDomain can only be created for function position!");
 4914  case IRPosition::IRP_FUNCTION:
 4915  AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
 4916  break;
4917  }
4918 
4919  return *AA;
4920 }
4921 
4922 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4923  Attributor &A) {
4924  AAHeapToSharedFunction *AA = nullptr;
4925  switch (IRP.getPositionKind()) {
 4926  case IRPosition::IRP_INVALID:
 4927  case IRPosition::IRP_FLOAT:
 4928  case IRPosition::IRP_ARGUMENT:
 4929  case IRPosition::IRP_CALL_SITE_ARGUMENT:
 4930  case IRPosition::IRP_RETURNED:
 4931  case IRPosition::IRP_CALL_SITE_RETURNED:
 4932  case IRPosition::IRP_CALL_SITE:
 4933  llvm_unreachable(
 4934  "AAHeapToShared can only be created for function position!");
 4935  case IRPosition::IRP_FUNCTION:
 4936  AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
 4937  break;
4938  }
4939 
4940  return *AA;
4941 }
4942 
4943 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4944  Attributor &A) {
4945  AAKernelInfo *AA = nullptr;
4946  switch (IRP.getPositionKind()) {
 4947  case IRPosition::IRP_INVALID:
 4948  case IRPosition::IRP_FLOAT:
 4949  case IRPosition::IRP_ARGUMENT:
 4950  case IRPosition::IRP_RETURNED:
 4951  case IRPosition::IRP_CALL_SITE_RETURNED:
 4952  case IRPosition::IRP_CALL_SITE_ARGUMENT:
 4953  llvm_unreachable("KernelInfo can only be created for function position!");
 4954  case IRPosition::IRP_CALL_SITE:
 4955  AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
 4956  break;
 4957  case IRPosition::IRP_FUNCTION:
 4958  AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
 4959  break;
4960  }
4961 
4962  return *AA;
4963 }
4964 
4965 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4966  Attributor &A) {
4967  AAFoldRuntimeCall *AA = nullptr;
4968  switch (IRP.getPositionKind()) {
 4969  case IRPosition::IRP_INVALID:
 4970  case IRPosition::IRP_FLOAT:
 4971  case IRPosition::IRP_ARGUMENT:
 4972  case IRPosition::IRP_RETURNED:
 4973  case IRPosition::IRP_FUNCTION:
 4974  case IRPosition::IRP_CALL_SITE:
 4975  case IRPosition::IRP_CALL_SITE_ARGUMENT:
 4976  llvm_unreachable("KernelInfo can only be created for call site position!");
 4977  case IRPosition::IRP_CALL_SITE_RETURNED:
 4978  AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
 4979  break;
4980  }
4981 
4982  return *AA;
4983 }
4984 
 4985 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
 4986  if (!containsOpenMP(M))
 4987  return PreservedAnalyses::all();
 4988  if (DisableOpenMPOptimizations)
 4989  return PreservedAnalyses::all();
 4990 
 4991  FunctionAnalysisManager &FAM =
 4992  AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
 4993  KernelSet Kernels = getDeviceKernels(M);
 4994 
 4995  if (PrintModuleBeforeOptimizations)
 4996  LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
4997 
4998  auto IsCalled = [&](Function &F) {
4999  if (Kernels.contains(&F))
5000  return true;
5001  for (const User *U : F.users())
5002  if (!isa<BlockAddress>(U))
5003  return true;
5004  return false;
5005  };
5006 
5007  auto EmitRemark = [&](Function &F) {
 5008  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 5009  ORE.emit([&]() {
5010  OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5011  return ORA << "Could not internalize function. "
5012  << "Some optimizations may not be possible. [OMP140]";
5013  });
5014  };
5015 
5016  // Create internal copies of each function if this is a kernel Module. This
5017  // allows iterprocedural passes to see every call edge.
5018  DenseMap<Function *, Function *> InternalizedMap;
5019  if (isOpenMPDevice(M)) {
5020  SmallPtrSet<Function *, 16> InternalizeFns;
5021  for (Function &F : M)
5022  if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
 5023  !DisableInternalization) {
 5024  if (Attributor::isInternalizable(F)) {
 5025  InternalizeFns.insert(&F);
5026  } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5027  EmitRemark(F);
5028  }
5029  }
5030 
5031  Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5032  }
5033 
5034  // Look at every function in the Module unless it was internalized.
 5035  SmallVector<Function *, 16> SCC;
 5036  for (Function &F : M)
5037  if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
5038  SCC.push_back(&F);
5039 
5040  if (SCC.empty())
5041  return PreservedAnalyses::all();
5042 
5043  AnalysisGetter AG(FAM);
5044 
 5045  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
 5046  return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
 5047  };
5048 
 5049  BumpPtrAllocator Allocator;
 5050  CallGraphUpdater CGUpdater;
5051 
5052  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels);
5053 
5054  unsigned MaxFixpointIterations =
 5055  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
 5056 
5057  AttributorConfig AC(CGUpdater);
5058  AC.DefaultInitializeLiveInternals = false;
5059  AC.RewriteSignatures = false;
5060  AC.MaxFixpointIterations = MaxFixpointIterations;
5061  AC.OREGetter = OREGetter;
5062  AC.PassName = DEBUG_TYPE;
5063 
5064  SetVector<Function *> Functions;
5065  Attributor A(Functions, InfoCache, AC);
5066 
5067  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5068  bool Changed = OMPOpt.run(true);
5069 
5070  // Optionally inline device functions for potentially better performance.
 5071  if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
 5072  for (Function &F : M)
5073  if (!F.isDeclaration() && !Kernels.contains(&F) &&
5074  !F.hasFnAttribute(Attribute::NoInline))
5075  F.addFnAttr(Attribute::AlwaysInline);
5076 
 5077  if (PrintModuleAfterOptimizations)
 5078  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5079 
5080  if (Changed)
5081  return PreservedAnalyses::none();
5082 
5083  return PreservedAnalyses::all();
5084 }
5085 
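For orientation, a hedged sketch of driving this module pass directly through the new pass manager; everything below is ordinary new-PM boilerplate rather than anything specific to this file, and the same pass is normally reached via the default pipelines or `opt -passes=openmp-opt`:

#include "llvm/IR/Module.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/IPO/OpenMPOpt.h"

llvm::PreservedAnalyses runOpenMPOptOnModule(llvm::Module &M) {
  llvm::PassBuilder PB;
  llvm::LoopAnalysisManager LAM;
  llvm::FunctionAnalysisManager FAM;
  llvm::CGSCCAnalysisManager CGAM;
  llvm::ModuleAnalysisManager MAM;
  // Standard registration so the pass can query function-level analyses.
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  llvm::ModulePassManager MPM;
  MPM.addPass(llvm::OpenMPOptPass());
  return MPM.run(M, MAM);
}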
 5086 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
 5087  CGSCCAnalysisManager &AM,
 5088  LazyCallGraph &CG,
5089  CGSCCUpdateResult &UR) {
5090  if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5091  return PreservedAnalyses::all();
 5092  if (DisableOpenMPOptimizations)
 5093  return PreservedAnalyses::all();
5094 
 5095  SmallVector<Function *, 16> SCC;
 5096  // If there are kernels in the module, we have to run on all SCC's.
5097  for (LazyCallGraph::Node &N : C) {
5098  Function *Fn = &N.getFunction();
5099  SCC.push_back(Fn);
5100  }
5101 
5102  if (SCC.empty())
5103  return PreservedAnalyses::all();
5104 
5105  Module &M = *C.begin()->getFunction().getParent();
5106 
 5107  if (PrintModuleBeforeOptimizations)
 5108  LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5109 
 5110  KernelSet Kernels = getDeviceKernels(M);
 5111 
 5112  FunctionAnalysisManager &FAM =
 5113  AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5114 
5115  AnalysisGetter AG(FAM);
5116 
5117  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
 5118  return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
 5119  };
5120 
 5121  BumpPtrAllocator Allocator;
 5122  CallGraphUpdater CGUpdater;
5123  CGUpdater.initialize(CG, C, AM, UR);
5124 
5125  SetVector<Function *> Functions(SCC.begin(), SCC.end());
5126  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5127  /*CGSCC*/ &Functions, Kernels);
5128 
5129  unsigned MaxFixpointIterations =
 5130  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
 5131 
5132  AttributorConfig AC(CGUpdater);
5133  AC.DefaultInitializeLiveInternals = false;
5134  AC.IsModulePass = false;
5135  AC.RewriteSignatures = false;
5136  AC.MaxFixpointIterations = MaxFixpointIterations;
5137  AC.OREGetter = OREGetter;
5138  AC.PassName = DEBUG_TYPE;
5139 
5140  Attributor A(Functions, InfoCache, AC);
5141 
5142  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5143  bool Changed = OMPOpt.run(false);
5144 
 5145  if (PrintModuleAfterOptimizations)
 5146  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5147 
5148  if (Changed)
5149  return PreservedAnalyses::none();
5150 
5151  return PreservedAnalyses::all();
5152 }
5153 
5154 namespace {
5155 
5156 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
5157  CallGraphUpdater CGUpdater;
5158  static char ID;
5159 
5160  OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
 5161  initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
 5162  }
5163 
5164  void getAnalysisUsage(AnalysisUsage &AU) const override {
 5165  CallGraphSCCPass::getAnalysisUsage(AU);
 5166  }
5167 
5168  bool runOnSCC(CallGraphSCC &CGSCC) override {
5169  if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
5170  return false;
5171  if (DisableOpenMPOptimizations || skipSCC(CGSCC))
5172  return false;
5173 
 5174  SmallVector<Function *, 16> SCC;
 5175  // If there are kernels in the module, we have to run on all SCC's.
5176  for (CallGraphNode *CGN : CGSCC) {
5177  Function *Fn = CGN->getFunction();
5178  if (!Fn || Fn->isDeclaration())
5179  continue;
5180  SCC.push_back(Fn);
5181  }
5182 
5183  if (SCC.empty())
5184  return false;
5185 
5186  Module &M = CGSCC.getCallGraph().getModule();
 5187  KernelSet Kernels = getDeviceKernels(M);
 5188 
5189  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
5190  CGUpdater.initialize(CG, CGSCC);
5191 
5192  // Maintain a map of functions to avoid rebuilding the ORE
 5193  DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
 5194  auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
5195  std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
5196  if (!ORE)
5197  ORE = std::make_unique<OptimizationRemarkEmitter>(F);
5198  return *ORE;
5199  };
5200 
5201  AnalysisGetter AG;
5202  SetVector<Function *> Functions(SCC.begin(), SCC.end());
 5203  BumpPtrAllocator Allocator;
 5204  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
5205  Allocator,
5206  /*CGSCC*/ &Functions, Kernels);
5207 
5208  unsigned MaxFixpointIterations =
 5209  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
 5210 
5211  AttributorConfig AC(CGUpdater);
5212  AC.DefaultInitializeLiveInternals = false;
5213  AC.IsModulePass = false;
5214  AC.RewriteSignatures = false;
5215  AC.MaxFixpointIterations = MaxFixpointIterations;
5216  AC.OREGetter = OREGetter;
5217  AC.PassName = DEBUG_TYPE;
5218 
5219  Attributor A(Functions, InfoCache, AC);
5220 
5221  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5222  bool Result = OMPOpt.run(false);
5223 
5225  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5226 
5227  return Result;
5228  }
5229 
5230  bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
5231 };
5232 
5233 } // end anonymous namespace
5234 
 5235 KernelSet llvm::omp::getDeviceKernels(Module &M) {
 5236  // TODO: Create a more cross-platform way of determining device kernels.
 5237  NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
 5238  KernelSet Kernels;
 5239 
5240  if (!MD)
5241  return Kernels;
5242 
5243  for (auto *Op : MD->operands()) {
5244  if (Op->getNumOperands() < 2)
5245  continue;
5246  MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5247  if (!KindID || KindID->getString() != "kernel")
5248  continue;
5249 
5250  Function *KernelFn =
5251  mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5252  if (!KernelFn)
5253  continue;
5254 
5255  ++NumOpenMPTargetRegionKernels;
5256 
5257  Kernels.insert(KernelFn);
5258  }
5259 
5260  return Kernels;
5261 }
5262 
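getDeviceKernels() above consumes metadata the device frontend emits. A hedged sketch of producing that shape for a function; the helper name is made up and the exact operand layout is an assumption based on the checks in the loop above (operand 0 is the function, operand 1 the "kernel" marker):

#include "llvm/IR/Constants.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"

static void markAsDeviceKernel(llvm::Module &M, llvm::Function &Kernel) {
  llvm::LLVMContext &Ctx = M.getContext();
  llvm::NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
  // Roughly: !0 = !{ptr @Kernel, !"kernel", i32 1}
  llvm::Metadata *Ops[] = {
      llvm::ConstantAsMetadata::get(&Kernel),
      llvm::MDString::get(Ctx, "kernel"),
      llvm::ConstantAsMetadata::get(
          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
  MD->addOperand(llvm::MDNode::get(Ctx, Ops));
}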
 5263 bool llvm::omp::containsOpenMP(Module &M) {
 5264  Metadata *MD = M.getModuleFlag("openmp");
5265  if (!MD)
5266  return false;
5267 
5268  return true;
5269 }
5270 
 5271 bool llvm::omp::isOpenMPDevice(Module &M) {
 5272  Metadata *MD = M.getModuleFlag("openmp-device");
5273  if (!MD)
5274  return false;
5275 
5276  return true;
5277 }
5278 
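Both queries above are plain module-flag lookups. A hedged sketch of how a producer would tag a module so this pass treats it as an OpenMP device module; the flag names match the strings checked above, while using the OpenMP version number as the flag value is an assumption about the frontend:

#include "llvm/IR/Module.h"

static void tagAsOpenMPDeviceModule(llvm::Module &M, unsigned Version) {
  // "openmp" marks any OpenMP module, "openmp-device" the device side.
  M.addModuleFlag(llvm::Module::Max, "openmp", Version);
  M.addModuleFlag(llvm::Module::Max, "openmp-device", Version);
}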
 5279 char OpenMPOptCGSCCLegacyPass::ID = 0;
 5280 
5281 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
5282  "OpenMP specific optimizations", false, false)
 5283 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
 5284 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
5285  "OpenMP specific optimizations", false, false)
5286 
 5287 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
 5288  return new OpenMPOptCGSCCLegacyPass();
5289 }
llvm::AA::isValidAtPosition
bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is valid at the position of VAC, that is a constant,...
Definition: Attributor.cpp:247
llvm::PreservedAnalyses
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
llvm::CallGraphUpdater
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
Definition: CallGraphUpdater.h:29
EnableVerboseRemarks
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
llvm::Argument
This class represents an incoming formal argument to a Function.
Definition: Argument.h:28
llvm::IRPosition::function
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition: Attributor.h:470
llvm::BasicBlock::end
iterator end()
Definition: BasicBlock.h:308
llvm::OptimizationRemarkMissed
Diagnostic information for missed-optimization remarks.
Definition: DiagnosticInfo.h:735
llvm::OpenMPIRBuilder::LocationDescription
Description of a LLVM-IR insertion point (IP) and a debug/source location (filename,...
Definition: OMPIRBuilder.h:203
getName
static StringRef getName(Value *V)
Definition: ProvenanceAnalysisEvaluator.cpp:42
SetFixpointIterations
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
MI
IRTranslator LLVM IR MI
Definition: IRTranslator.cpp:108
Merge
R600 Clause Merge
Definition: R600ClauseMergePass.cpp:70
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
BlockSize
static const int BlockSize
Definition: TarWriter.cpp:33
llvm::CastInst::CreatePointerBitCastOrAddrSpaceCast
static CastInst * CreatePointerBitCastOrAddrSpaceCast(Value *S, Type *Ty, const Twine &Name, BasicBlock *InsertAtEnd)
Create a BitCast or an AddrSpaceCast cast instruction.
Definition: Instructions.cpp:3379
llvm::Instruction::getModule
const Module * getModule() const
Return the module owning the function this instruction belongs to, or nullptr if the function does not...
Definition: Instruction.cpp:65
llvm::CmpInst::ICMP_EQ
@ ICMP_EQ
equal
Definition: InstrTypes.h:740
llvm::NamedMDNode
A tuple of MDNodes.
Definition: Metadata.h:1588
llvm::drop_begin
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:268
OpenMPOpt.h
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:113
llvm::ISD::OR
@ OR
Definition: ISDOpcodes.h:667
llvm::Value::hasOneUse
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
llvm::Type::getInt8PtrTy
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
Definition: Type.cpp:291
DisableOpenMPOptimizations
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
llvm::BasicBlock::getParent
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:104
IntrinsicInst.h
llvm::AnalysisManager::getResult
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:774
llvm::cl::Prefix
@ Prefix
Definition: CommandLine.h:161
llvm::MemTransferInst
This class wraps the llvm.memcpy/memmove intrinsics.
Definition: IntrinsicInst.h:1026
llvm::Function
Definition: Function.h:60
DisableOpenMPOptDeglobalization
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
llvm::Attribute
Definition: Attributes.h:65
llvm::DenseMapBase< DenseMap< KeyT, ValueT, DenseMapInfo< KeyT >, llvm::detail::DenseMapPair< KeyT, ValueT > >, KeyT, ValueT, DenseMapInfo< KeyT >, llvm::detail::DenseMapPair< KeyT, ValueT > >::lookup
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:197
StringRef.h
TAG
static constexpr auto TAG
Definition: OpenMPOpt.cpp:172
llvm::raw_string_ostream
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:628
llvm::SetVector< T, SmallVector< T, N >, SmallDenseSet< T, N > >::size
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:77
contains
return AArch64::GPR64RegClass contains(Reg)
PrintModuleAfterOptimizations
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
llvm::GlobalValue::NotThreadLocal
@ NotThreadLocal
Definition: GlobalValue.h:192
llvm::ilist_node_with_parent::getNextNode
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition: ilist_node.h:289
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1181
Statistic.h
llvm::initializeOpenMPOptCGSCCLegacyPassPass
void initializeOpenMPOptCGSCCLegacyPassPass(PassRegistry &)
llvm::IRPosition::value
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition: Attributor.h:451
llvm::OpenMPOptCGSCCPass::run
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
Definition: OpenMPOpt.cpp:5086
llvm::Type::getPointerAddressSpace
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
Definition: DerivedTypes.h:729
llvm::Function::getEntryBlock
const BasicBlock & getEntryBlock() const
Definition: Function.h:710
llvm::OpenMPIRBuilder::InsertPointTy
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
Definition: OMPIRBuilder.h:97
llvm::IRBuilder
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2524
llvm::StateWrapper
Helper to tie an abstract state implementation to an abstract attribute.
Definition: Attributor.h:2821
llvm::GlobalVariable
Definition: GlobalVariable.h:39
llvm::SmallDenseMap
Definition: DenseMap.h:880
llvm::CmpInst::ICMP_NE
@ ICMP_NE
not equal
Definition: InstrTypes.h:741
llvm::FunctionType::get
static FunctionType * get(Type *Result, ArrayRef< Type * > Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:361
ValueTracking.h
OptimizationRemarkEmitter.h
llvm::CallGraph
The basic data container for the call graph of a Module of IR.
Definition: CallGraph.h:72
llvm::AA::Interprocedural
@ Interprocedural
Definition: Attributor.h:159
llvm::DominatorTree
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:166
FAM
FunctionAnalysisManager FAM
Definition: PassBuilderBindings.cpp:59
llvm::cl::Hidden
@ Hidden
Definition: CommandLine.h:140
llvm::DenseMapBase< DenseMap< KeyT, ValueT, DenseMapInfo< KeyT >, llvm::detail::DenseMapPair< KeyT, ValueT > >, KeyT, ValueT, DenseMapInfo< KeyT >, llvm::detail::DenseMapPair< KeyT, ValueT > >::begin
iterator begin()
Definition: DenseMap.h:75
llvm::PreservedAnalyses::none
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: PassManager.h:155
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
llvm::sys::path::end
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:235
llvm::sys::path::begin
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:226
llvm::CallBase::isCallee
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Definition: InstrTypes.h:1407
llvm::Use::getOperandNo
unsigned getOperandNo() const
Return the operand # of this use in its User.
Definition: Use.cpp:31
llvm::BasicBlock::splitBasicBlock
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="", bool Before=false)
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:402
llvm::operator!=
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:1992
llvm::InformationCache
Data structure to hold cached (LLVM-IR) information.
Definition: Attributor.h:996
llvm::MemIntrinsic
This is the common base class for memset/memcpy/memmove.
Definition: IntrinsicInst.h:961
llvm::AttributorConfig
Configuration for the Attributor.
Definition: Attributor.h:1242
llvm::Optional
Definition: APInt.h:33
llvm::SmallPtrSet< Instruction *, 4 >
llvm::ore::NV
DiagnosticInfoOptimizationBase::Argument NV
Definition: OptimizationRemarkEmitter.h:136
llvm::tgtok::FalseVal
@ FalseVal
Definition: TGLexer.h:62
llvm::omp::AddressSpace::Constant
@ Constant
llvm::max
Expected< ExpressionValue > max(const ExpressionValue &Lhs, const ExpressionValue &Rhs)
Definition: FileCheck.cpp:337
llvm::IRPosition::IRP_ARGUMENT
@ IRP_ARGUMENT
An attribute for a function argument.
Definition: Attributor.h:441
llvm::IRPosition::IRP_RETURNED
@ IRP_RETURNED
An attribute for the function return value.
Definition: Attributor.h:437
llvm::MipsISD::Ret
@ Ret
Definition: MipsISelLowering.h:119
initialize
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, ArrayRef< StringLiteral > StandardNames)
Initialize the set of available library functions based on the specified target triple.
Definition: TargetLibraryInfo.cpp:150
llvm::AbstractCallSite
AbstractCallSite.
Definition: AbstractCallSite.h:50
RHS
Value * RHS
Definition: X86PartialReduction.cpp:76
llvm::GlobalValue::LinkageTypes
LinkageTypes
An enumeration for the kinds of linkage for global values.
Definition: GlobalValue.h:47
llvm::AAExecutionDomain::ID
static const char ID
Unique ID (due to the unique address)
Definition: Attributor.h:4861
llvm::CallBase::addParamAttr
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
Definition: InstrTypes.h:1526
DisableOpenMPOptBarrierElimination
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
llvm::Type::getInt8Ty
static IntegerType * getInt8Ty(LLVMContext &C)
Definition: Type.cpp:237
llvm::IRPosition::IRP_FLOAT
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition: Attributor.h:435
llvm::Type::getInt32Ty
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:239
LLVM_DEBUG
#define LLVM_DEBUG(X)
Definition: Debug.h:101
llvm::Instruction::mayHaveSideEffects
bool mayHaveSideEffects() const
Return true if the instruction may have side effects.
Definition: Instruction.cpp:721
PrintICVValues
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
F
#define F(x, y, z)
Definition: MD5.cpp:55
llvm::IntegerStateBase< bool, true, false >::operator^=
void operator^=(const IntegerStateBase< base_t, BestState, WorstState > &R)
"Clamp" this state with R.
Definition: Attributor.h:2347
llvm::ConstantExpr::getPointerCast
static Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
Definition: Constants.cpp:2014
llvm::BasicBlock::getUniqueSuccessor
const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
Definition: BasicBlock.cpp:323
DisableOpenMPOptSPMDization
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
llvm::BasicBlock
LLVM Basic Block Representation.
Definition: BasicBlock.h:55
llvm::AMDGPU::isKernel
LLVM_READNONE bool isKernel(CallingConv::ID CC)
Definition: AMDGPUBaseInfo.h:817
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
llvm::CallGraphUpdater::finalize
bool finalize()
}
Definition: CallGraphUpdater.cpp:24
Arg
amdgpu Simplify well known AMD library false FunctionCallee Value * Arg
Definition: AMDGPULibCalls.cpp:187
Instruction.h
llvm::AMDGPU::HSAMD::Key::Kernels
constexpr char Kernels[]
Key for HSA::Metadata::mKernels.
Definition: AMDGPUMetadata.h:432
CommandLine.h
Printer
print alias Alias Set Printer
Definition: AliasSetTracker.cpp:734
llvm::ConstantInt
This is the shared class of boolean and integer constants.
Definition: Constants.h:79
llvm::all_of
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1590
llvm::operator&=
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
Definition: SparseBitVector.h:835