LLVM  15.0.0git
AMDGPULowerKernelAttributes.cpp
Go to the documentation of this file.
//===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
//
//===----------------------------------------------------------------------===//
14 
#include "AMDGPU.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
26 
27 #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
28 
29 using namespace llvm;
30 
31 namespace {
32 
// Field offsets in hsa_kernel_dispatch_packet_t.
//
// These are the constant byte offsets from the dispatch pointer that
// processUse() compares against. It accepts the WORKGROUP_SIZE_* fields only
// when read with a 2-byte load and the GRID_SIZE_* fields only when read with
// a 4-byte load.
enum DispatchPackedOffsets {
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};
43 
44 class AMDGPULowerKernelAttributes : public ModulePass {
45 public:
46  static char ID;
47 
48  AMDGPULowerKernelAttributes() : ModulePass(ID) {}
49 
50  bool runOnModule(Module &M) override;
51 
52  StringRef getPassName() const override {
53  return "AMDGPU Kernel Attributes";
54  }
55 
56  void getAnalysisUsage(AnalysisUsage &AU) const override {
57  AU.setPreservesAll();
58  }
59 };
60 
61 } // end anonymous namespace
62 
63 static bool processUse(CallInst *CI) {
64  Function *F = CI->getParent()->getParent();
65 
66  auto MD = F->getMetadata("reqd_work_group_size");
67  const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
68 
69  const bool HasUniformWorkGroupSize =
70  F->getFnAttribute("uniform-work-group-size").getValueAsBool();
71 
72  if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
73  return false;
74 
75  Value *WorkGroupSizeX = nullptr;
76  Value *WorkGroupSizeY = nullptr;
77  Value *WorkGroupSizeZ = nullptr;
78 
79  Value *GridSizeX = nullptr;
80  Value *GridSizeY = nullptr;
81  Value *GridSizeZ = nullptr;
82 
83  const DataLayout &DL = F->getParent()->getDataLayout();
84 
85  // We expect to see several GEP users, casted to the appropriate type and
86  // loaded.
87  for (User *U : CI->users()) {
88  if (!U->hasOneUse())
89  continue;
90 
91  int64_t Offset = 0;
92  if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
93  continue;
94 
95  auto *BCI = dyn_cast<BitCastInst>(*U->user_begin());
96  if (!BCI || !BCI->hasOneUse())
97  continue;
98 
99  auto *Load = dyn_cast<LoadInst>(*BCI->user_begin());
100  if (!Load || !Load->isSimple())
101  continue;
102 
103  unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
104 
105  // TODO: Handle merged loads.
106  switch (Offset) {
107  case WORKGROUP_SIZE_X:
108  if (LoadSize == 2)
109  WorkGroupSizeX = Load;
110  break;
111  case WORKGROUP_SIZE_Y:
112  if (LoadSize == 2)
113  WorkGroupSizeY = Load;
114  break;
115  case WORKGROUP_SIZE_Z:
116  if (LoadSize == 2)
117  WorkGroupSizeZ = Load;
118  break;
119  case GRID_SIZE_X:
120  if (LoadSize == 4)
121  GridSizeX = Load;
122  break;
123  case GRID_SIZE_Y:
124  if (LoadSize == 4)
125  GridSizeY = Load;
126  break;
127  case GRID_SIZE_Z:
128  if (LoadSize == 4)
129  GridSizeZ = Load;
130  break;
131  default:
132  break;
133  }
134  }
135 
136  // Pattern match the code used to handle partial workgroup dispatches in the
137  // library implementation of get_local_size, so the entire function can be
138  // constant folded with a known group size.
139  //
140  // uint r = grid_size - group_id * group_size;
141  // get_local_size = (r < group_size) ? r : group_size;
142  //
143  // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
144  // the grid_size is required to be a multiple of group_size). In this case:
145  //
146  // grid_size - (group_id * group_size) < group_size
147  // ->
148  // grid_size < group_size + (group_id * group_size)
149  //
150  // (grid_size / group_size) < 1 + group_id
151  //
152  // grid_size / group_size is at least 1, so we can conclude the select
153  // condition is false (except for group_id == 0, where the select result is
154  // the same).
155 
156  bool MadeChange = false;
157  Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ };
158  Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ };
159 
160  for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) {
161  Value *GroupSize = WorkGroupSizes[I];
162  Value *GridSize = GridSizes[I];
163  if (!GroupSize || !GridSize)
164  continue;
165 
166  using namespace llvm::PatternMatch;
167  auto GroupIDIntrin =
168  I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
169  : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
170  : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
171 
172  for (User *U : GroupSize->users()) {
173  auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
174  if (!ZextGroupSize)
175  continue;
176 
177  for (User *UMin : ZextGroupSize->users()) {
178  if (match(UMin,
179  m_UMin(m_Sub(m_Specific(GridSize),
180  m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))),
181  m_Specific(ZextGroupSize)))) {
182  if (HasReqdWorkGroupSize) {
183  ConstantInt *KnownSize
184  = mdconst::extract<ConstantInt>(MD->getOperand(I));
185  UMin->replaceAllUsesWith(ConstantExpr::getIntegerCast(
186  KnownSize, UMin->getType(), false));
187  } else {
188  UMin->replaceAllUsesWith(ZextGroupSize);
189  }
190 
191  MadeChange = true;
192  }
193  }
194  }
195  }
196 
197  if (!HasReqdWorkGroupSize)
198  return MadeChange;
199 
200  // Eliminate any other loads we can from the dispatch packet.
201  for (int I = 0; I < 3; ++I) {
202  Value *GroupSize = WorkGroupSizes[I];
203  if (!GroupSize)
204  continue;
205 
206  ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
207  GroupSize->replaceAllUsesWith(
209  GroupSize->getType(),
210  false));
211  MadeChange = true;
212  }
213 
214  return MadeChange;
215 }
216 
217 // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
218 // TargetPassConfig for subtarget.
219 bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
220  StringRef DispatchPtrName
221  = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
222 
223  Function *DispatchPtr = M.getFunction(DispatchPtrName);
224  if (!DispatchPtr) // Dispatch ptr not used.
225  return false;
226 
227  bool MadeChange = false;
228 
229  SmallPtrSet<Instruction *, 4> HandledUses;
230  for (auto *U : DispatchPtr->users()) {
231  CallInst *CI = cast<CallInst>(U);
232  if (HandledUses.insert(CI).second) {
233  if (processUse(CI))
234  MadeChange = true;
235  }
236  }
237 
238  return MadeChange;
239 }
240 
241 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
242  "AMDGPU Kernel Attributes", false, false)
243 INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
245 
246 char AMDGPULowerKernelAttributes::ID = 0;
247 
249  return new AMDGPULowerKernelAttributes();
250 }
251 
254  StringRef DispatchPtrName =
255  Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
256 
257  Function *DispatchPtr = F.getParent()->getFunction(DispatchPtrName);
258  if (!DispatchPtr) // Dispatch ptr not used.
259  return PreservedAnalyses::all();
260 
261  for (Instruction &I : instructions(F)) {
262  if (CallInst *CI = dyn_cast<CallInst>(&I)) {
263  if (CI->getCalledFunction() == DispatchPtr)
264  processUse(CI);
265  }
266  }
267 
268  return PreservedAnalyses::all();
269 }
llvm::PreservedAnalyses
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:17
M
We currently emits eax Perhaps this is what we really should generate is Is imull three or four cycles eax eax The current instruction priority is based on pattern complexity The former is more complex because it folds a load so the latter will not be emitted Perhaps we should use AddedComplexity to give LEA32r a higher priority We should always try to match LEA first since the LEA matching code does some estimate to determine whether the match is profitable if we care more about code then imull is better It s two bytes shorter than movl leal On a Pentium M
Definition: README.txt:252
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:113
llvm::ModulePass
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:248
llvm::BasicBlock::getParent
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:104
InstIterator.h
llvm::Function
Definition: Function.h:60
Pass.h
llvm::Intrinsic::getName
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:879
ValueTracking.h
llvm::SmallPtrSet< Instruction *, 4 >
llvm::AMDGPULowerKernelAttributesPass::run
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Definition: AMDGPULowerKernelAttributes.cpp:253
F
#define F(x, y, z)
Definition: MD5.cpp:55
llvm::SPII::Load
@ Load
Definition: SparcInstrInfo.h:32
llvm::ConstantInt
This is the shared class of boolean and integer constants.
Definition: Constants.h:79
Constants.h
llvm::PatternMatch::match
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
llvm::User
Definition: User.h:44
llvm::CallBase::getCalledFunction
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1396
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
false
Definition: StackSlotColoring.cpp:141
AMDGPU
Definition: AMDGPUReplaceLDSUseWithPointer.cpp:114
llvm::Instruction
Definition: Instruction.h:42
llvm::PatternMatch::m_UMin
MaxMin_match< ICmpInst, LHS, RHS, umin_pred_ty > m_UMin(const LHS &L, const RHS &R)
Definition: PatternMatch.h:1865
PatternMatch.h
llvm::CallingConv::ID
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
INITIALIZE_PASS_END
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:58
llvm::RecurKind::UMin
@ UMin
Unisgned integer min implemented in terms of select(cmp()).
Passes.h
llvm::instructions
inst_range instructions(Function *F)
Definition: InstIterator.h:133
processUse
static bool processUse(CallInst *CI)
Definition: AMDGPULowerKernelAttributes.cpp:63
I
#define I(x, y, z)
Definition: MD5.cpp:58
Attributes
AMDGPU Kernel Attributes
Definition: AMDGPULowerKernelAttributes.cpp:244
TargetPassConfig.h
llvm::PatternMatch::m_Sub
BinaryOp_match< LHS, RHS, Instruction::Sub > m_Sub(const LHS &L, const RHS &R)
Definition: PatternMatch.h:985
llvm::Module
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:58
AMDGPU.h
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
llvm::Value::replaceAllUsesWith
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:529
DL
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Definition: AArch64SLSHardening.cpp:76
llvm::ConstantExpr::getIntegerCast
static Constant * getIntegerCast(Constant *C, Type *Ty, bool IsSigned)
Create a ZExt, Bitcast or Trunc for integer -> integer casts.
Definition: Constants.cpp:2104
llvm::createAMDGPULowerKernelAttributesPass
ModulePass * createAMDGPULowerKernelAttributesPass()
Definition: AMDGPULowerKernelAttributes.cpp:248
llvm::AnalysisUsage::setPreservesAll
void setPreservesAll()
Set by analyses that do not transform their input at all.
Definition: PassAnalysisSupport.h:130
llvm::PreservedAnalyses::all
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
Function.h
llvm::GetPointerBaseWithConstantOffset
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
Definition: ValueTracking.h:278
Instructions.h
llvm::PatternMatch::m_Specific
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:766
DEBUG_TYPE
#define DEBUG_TYPE
Definition: AMDGPULowerKernelAttributes.cpp:27
llvm::Instruction::getParent
const BasicBlock * getParent() const
Definition: Instruction.h:91
llvm::PatternMatch
Definition: PatternMatch.h:47
llvm::AnalysisManager
A container for analyses that lazily runs them and caches their results.
Definition: InstructionSimplify.h:42
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1461
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU Kernel Attributes", false, false) INITIALIZE_PASS_END(AMDGPULowerKernelAttributes
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
llvm::Value::users
iterator_range< user_iterator > users()
Definition: Value.h:421
llvm::PatternMatch::m_Mul
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
Definition: PatternMatch.h:1039
llvm::SmallPtrSetImpl::insert
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:38