AMDGPULowerKernelAttributes.cpp
//===-- AMDGPULowerKernelAttributes.cpp------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-attributes"

using namespace llvm;

namespace {

// Field offsets in hsa_kernel_dispatch_packet_t.
enum DispatchPackedOffsets {
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};
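
// For reference, the fields these offsets index look roughly like the start of
// the HSA dispatch packet (sketch based on the HSA packet layout, not a header
// included here):
//   uint16_t header;            // byte offset 0
//   uint16_t setup;             // byte offset 2
//   uint16_t workgroup_size_x;  // byte offset 4
//   uint16_t workgroup_size_y;  // byte offset 6
//   uint16_t workgroup_size_z;  // byte offset 8
//   uint16_t reserved0;         // byte offset 10
//   uint32_t grid_size_x;       // byte offset 12
//   uint32_t grid_size_y;       // byte offset 16
//   uint32_t grid_size_z;       // byte offset 20
// which is why the group-size loads below are 2 bytes wide and the grid-size
// loads are 4 bytes wide.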

// Field offsets to implicit kernel argument pointer.
enum ImplicitArgOffsets {
  HIDDEN_BLOCK_COUNT_X = 0,
  HIDDEN_BLOCK_COUNT_Y = 4,
  HIDDEN_BLOCK_COUNT_Z = 8,

  HIDDEN_GROUP_SIZE_X = 12,
  HIDDEN_GROUP_SIZE_Y = 14,
  HIDDEN_GROUP_SIZE_Z = 16,

  HIDDEN_REMAINDER_X = 18,
  HIDDEN_REMAINDER_Y = 20,
  HIDDEN_REMAINDER_Z = 22,
};

class AMDGPULowerKernelAttributes : public ModulePass {
public:
  static char ID;

  AMDGPULowerKernelAttributes() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override { return "AMDGPU Kernel Attributes"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
  }
};

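// For code object v5 and above, group sizes, remainders and block counts are
// read from the hidden implicit kernel arguments (llvm.amdgcn.implicitarg.ptr);
// for older code objects they come from the HSA dispatch packet
// (llvm.amdgcn.dispatch.ptr). Returns the matching intrinsic declaration if the
// module uses it, otherwise null.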
Function *getBasePtrIntrinsic(Module &M, bool IsV5OrAbove) {
  auto IntrinsicId = IsV5OrAbove ? Intrinsic::amdgcn_implicitarg_ptr
                                 : Intrinsic::amdgcn_dispatch_ptr;
  return Intrinsic::getDeclarationIfExists(&M, IntrinsicId);
}

} // end anonymous namespace

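// Attach !range metadata to a 32-bit block-count load when the kernel has a
// finite "amdgpu-max-num-workgroups" bound in that dimension. Illustrative IR
// (value names invented), assuming MaxNumGroups == 16:
//   %block.count.x = load i32, ptr addrspace(4) %gep, !range !0
//   !0 = !{i32 1, i32 17}  ; half-open range [1, 17), i.e. 1 to 16 groups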
static void annotateGridSizeLoadWithRangeMD(LoadInst *Load,
                                            uint32_t MaxNumGroups) {
  if (MaxNumGroups == 0 || MaxNumGroups == std::numeric_limits<uint32_t>::max())
    return;

  if (!Load->getType()->isIntegerTy(32))
    return;

  // TODO: If there is existing range metadata, preserve it if it is stricter.
  MDBuilder MDB(Load->getContext());
  MDNode *Range = MDB.createRange(APInt(32, 1), APInt(32, MaxNumGroups + 1));
  Load->setMetadata(LLVMContext::MD_range, Range);
}

static bool processUse(CallInst *CI, bool IsV5OrAbove) {
  Function *F = CI->getFunction();

  auto *MD = F->getMetadata("reqd_work_group_size");
  const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;

  const bool HasUniformWorkGroupSize =
      F->getFnAttribute("uniform-work-group-size").getValueAsBool();

  SmallVector<unsigned> MaxNumWorkgroups =
      AMDGPU::getIntegerVecAttribute(*F, "amdgpu-max-num-workgroups",
                                     /*Size=*/3, /*DefaultVal=*/0);

  if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize &&
      !Intrinsic::getDeclarationIfExists(CI->getModule(),
                                         Intrinsic::amdgcn_dispatch_ptr) &&
      none_of(MaxNumWorkgroups, [](unsigned X) { return X != 0; }))
    return false;

  Value *BlockCounts[3] = {nullptr, nullptr, nullptr};
  Value *GroupSizes[3] = {nullptr, nullptr, nullptr};
  Value *Remainders[3] = {nullptr, nullptr, nullptr};
  Value *GridSizes[3] = {nullptr, nullptr, nullptr};

  const DataLayout &DL = F->getDataLayout();

  // We expect to see several GEP users, casted to the appropriate type and
  // loaded.
  for (User *U : CI->users()) {
    if (!U->hasOneUse())
      continue;

    int64_t Offset = 0;
    auto *Load = dyn_cast<LoadInst>(U); // Load from ImplicitArgPtr/DispatchPtr?
    auto *BCI = dyn_cast<BitCastInst>(U);
    if (!Load && !BCI) {
      if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
        continue;
      Load = dyn_cast<LoadInst>(*U->user_begin()); // Load from GEP?
      BCI = dyn_cast<BitCastInst>(*U->user_begin());
    }

    if (BCI) {
      if (!BCI->hasOneUse())
        continue;
      Load = dyn_cast<LoadInst>(*BCI->user_begin()); // Load from BCI?
    }

    if (!Load || !Load->isSimple())
      continue;

    unsigned LoadSize = DL.getTypeStoreSize(Load->getType());

    // TODO: Handle merged loads.
    if (IsV5OrAbove) { // Base is ImplicitArgPtr.
      switch (Offset) {
      case HIDDEN_BLOCK_COUNT_X:
        if (LoadSize == 4) {
          BlockCounts[0] = Load;
          annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[0]);
        }
        break;
      case HIDDEN_BLOCK_COUNT_Y:
        if (LoadSize == 4) {
          BlockCounts[1] = Load;
          annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[1]);
        }
        break;
      case HIDDEN_BLOCK_COUNT_Z:
        if (LoadSize == 4) {
          BlockCounts[2] = Load;
          annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[2]);
        }
        break;
      case HIDDEN_GROUP_SIZE_X:
        if (LoadSize == 2)
          GroupSizes[0] = Load;
        break;
      case HIDDEN_GROUP_SIZE_Y:
        if (LoadSize == 2)
          GroupSizes[1] = Load;
        break;
      case HIDDEN_GROUP_SIZE_Z:
        if (LoadSize == 2)
          GroupSizes[2] = Load;
        break;
      case HIDDEN_REMAINDER_X:
        if (LoadSize == 2)
          Remainders[0] = Load;
        break;
      case HIDDEN_REMAINDER_Y:
        if (LoadSize == 2)
          Remainders[1] = Load;
        break;
      case HIDDEN_REMAINDER_Z:
        if (LoadSize == 2)
          Remainders[2] = Load;
        break;
      default:
        break;
      }
    } else { // Base is DispatchPtr.
      switch (Offset) {
      case WORKGROUP_SIZE_X:
        if (LoadSize == 2)
          GroupSizes[0] = Load;
        break;
      case WORKGROUP_SIZE_Y:
        if (LoadSize == 2)
          GroupSizes[1] = Load;
        break;
      case WORKGROUP_SIZE_Z:
        if (LoadSize == 2)
          GroupSizes[2] = Load;
        break;
      case GRID_SIZE_X:
        if (LoadSize == 4)
          GridSizes[0] = Load;
        break;
      case GRID_SIZE_Y:
        if (LoadSize == 4)
          GridSizes[1] = Load;
        break;
      case GRID_SIZE_Z:
        if (LoadSize == 4)
          GridSizes[2] = Load;
        break;
      default:
        break;
      }
    }
  }

  bool MadeChange = false;
  if (IsV5OrAbove && HasUniformWorkGroupSize) {
    // Under v5 __ockl_get_local_size returns the value computed by the
    // expression:
    //
    //   workgroup_id < hidden_block_count ? hidden_group_size :
    //   hidden_remainder
    //
    // For functions with the attribute uniform-work-group-size=true, we can
    // evaluate workgroup_id < hidden_block_count as true, and thus
    // hidden_group_size is returned for __ockl_get_local_size.
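    //
    // Illustrative IR for the piece matched below (value names invented):
    //   %id  = call i32 @llvm.amdgcn.workgroup.id.x()
    //   %cmp = icmp ult i32 %id, %block.count.x
    // Since the comparison is always true here, it is replaced with 'true' and
    // the select between group size and remainder folds to the group size.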
    for (int I = 0; I < 3; ++I) {
      Value *BlockCount = BlockCounts[I];
      if (!BlockCount)
        continue;

      using namespace llvm::PatternMatch;
      auto GroupIDIntrin =
          I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
                 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
                           : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

      for (User *ICmp : BlockCount->users()) {
        if (match(ICmp, m_SpecificICmp(ICmpInst::ICMP_ULT, GroupIDIntrin,
                                       m_Specific(BlockCount)))) {
          ICmp->replaceAllUsesWith(llvm::ConstantInt::getTrue(ICmp->getType()));
          MadeChange = true;
        }
      }
    }

    // All remainders should be 0 with uniform work group size.
    for (Value *Remainder : Remainders) {
      if (!Remainder)
        continue;
      Remainder->replaceAllUsesWith(
          Constant::getNullValue(Remainder->getType()));
      MadeChange = true;
    }
  } else if (HasUniformWorkGroupSize) { // Pre-V5.
    // Pattern match the code used to handle partial workgroup dispatches in the
    // library implementation of get_local_size, so the entire function can be
    // constant folded with a known group size.
    //
    //   uint r = grid_size - group_id * group_size;
    //   get_local_size = (r < group_size) ? r : group_size;
    //
    // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
    // the grid_size is required to be a multiple of group_size. In this case:
    //
    //   grid_size - (group_id * group_size) < group_size
    //   ->
    //   grid_size < group_size + (group_id * group_size)
    //
    //   (grid_size / group_size) < 1 + group_id
    //
    // grid_size / group_size is at least 1, so we can conclude the select
    // condition is false (except for group_id == 0, where the select result is
    // the same).
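    //
    // Illustrative IR for the matched pattern (value names invented, shown with
    // the umin intrinsic; the equivalent select form is matched as well):
    //   %gs.zext = zext i16 %group.size.x to i32
    //   %mul     = mul i32 %group.id.x, %gs.zext
    //   %sub     = sub i32 %grid.size.x, %mul
    //   %lsz     = call i32 @llvm.umin.i32(i32 %sub, i32 %gs.zext)
    // %lsz is replaced with %gs.zext, or with the constant taken from
    // reqd_work_group_size when that metadata is present.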
    for (int I = 0; I < 3; ++I) {
      Value *GroupSize = GroupSizes[I];
      Value *GridSize = GridSizes[I];
      if (!GroupSize || !GridSize)
        continue;

      using namespace llvm::PatternMatch;
      auto GroupIDIntrin =
          I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
                 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
                           : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

      for (User *U : GroupSize->users()) {
        auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
        if (!ZextGroupSize)
          continue;

        for (User *UMin : ZextGroupSize->users()) {
          if (match(UMin, m_UMin(m_Sub(m_Specific(GridSize),
                                       m_Mul(GroupIDIntrin,
                                             m_Specific(ZextGroupSize))),
                                 m_Specific(ZextGroupSize)))) {
            if (HasReqdWorkGroupSize) {
              ConstantInt *KnownSize =
                  mdconst::extract<ConstantInt>(MD->getOperand(I));
              UMin->replaceAllUsesWith(ConstantFoldIntegerCast(
                  KnownSize, UMin->getType(), false, DL));
            } else {
              UMin->replaceAllUsesWith(ZextGroupSize);
            }

            MadeChange = true;
          }
        }
      }
    }
  }

  // Upgrade the old method of calculating the block count from the grid size.
  // We pattern match any case where the implicit argument group size is the
  // divisor to a dispatch packet grid size read of the same dimension.
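  // Illustrative IR for the rewrite (value names invented):
  //   %grid.size.x = load i32, ptr addrspace(4) %dispatch.gep  ; GRID_SIZE_X
  //   %gs.zext     = zext i16 %group.size.x to i32
  //   %num.blocks  = udiv i32 %grid.size.x, %gs.zext
  // The udiv is replaced by a load of hidden_block_count_x from the implicit
  // argument pointer, zero-extended to the udiv's type.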
  if (IsV5OrAbove) {
    for (int I = 0; I < 3; I++) {
      Value *GroupSize = GroupSizes[I];
      if (!GroupSize || !GroupSize->getType()->isIntegerTy(16))
        continue;

      for (User *U : GroupSize->users()) {
        Instruction *Inst = cast<Instruction>(U);
        if (isa<ZExtInst>(Inst) && !Inst->use_empty())
          Inst = cast<Instruction>(*Inst->user_begin());

        using namespace llvm::PatternMatch;
        if (!match(
                Inst,
                m_UDiv(m_ZExtOrSelf(m_Load(m_GEP(
                           m_Intrinsic<Intrinsic::amdgcn_dispatch_ptr>(),
                           m_SpecificInt(GRID_SIZE_X + I * sizeof(uint32_t))))),
                       m_Value())))
          continue;

        IRBuilder<> Builder(Inst);

        Value *GEP = Builder.CreateInBoundsGEP(
            Builder.getInt8Ty(), CI,
            {ConstantInt::get(Type::getInt64Ty(CI->getContext()),
                              HIDDEN_BLOCK_COUNT_X + I * sizeof(uint32_t))});
        Instruction *BlockCount = Builder.CreateLoad(Builder.getInt32Ty(), GEP);
        BlockCount->setMetadata(LLVMContext::MD_invariant_load,
                                MDNode::get(CI->getContext(), {}));
        BlockCount->setMetadata(LLVMContext::MD_noundef,
                                MDNode::get(CI->getContext(), {}));

        Value *BlockCountExt = Builder.CreateZExt(BlockCount, Inst->getType());
        Inst->replaceAllUsesWith(BlockCountExt);
        Inst->eraseFromParent();
        MadeChange = true;
      }
    }
  }

  // If reqd_work_group_size is set, we can replace work group size with it.
  if (!HasReqdWorkGroupSize)
    return MadeChange;

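  // For example, with !reqd_work_group_size !{i32 64, i32 1, i32 1}, every
  // group-size load in dimension 0 folds to 64 and the loads in dimensions 1
  // and 2 fold to 1, each cast to the load's own type.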
  for (int I = 0; I < 3; I++) {
    Value *GroupSize = GroupSizes[I];
    if (!GroupSize)
      continue;

    ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
    GroupSize->replaceAllUsesWith(
        ConstantFoldIntegerCast(KnownSize, GroupSize->getType(), false, DL));
    MadeChange = true;
  }

  return MadeChange;
}

// TODO: Move makeLIDRangeMetadata usage into here. Seems to not get
// TargetPassConfig for subtarget.
bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
  bool MadeChange = false;
  bool IsV5OrAbove =
      AMDGPU::getAMDHSACodeObjectVersion(M) >= AMDGPU::AMDHSA_COV5;
  Function *BasePtr = getBasePtrIntrinsic(M, IsV5OrAbove);

  if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
    return false;

  SmallPtrSet<Instruction *, 4> HandledUses;
  for (auto *U : BasePtr->users()) {
    CallInst *CI = cast<CallInst>(U);
    if (HandledUses.insert(CI).second) {
      if (processUse(CI, IsV5OrAbove))
        MadeChange = true;
    }
  }

  return MadeChange;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                      "AMDGPU Kernel Attributes", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                    "AMDGPU Kernel Attributes", false, false)

char AMDGPULowerKernelAttributes::ID = 0;

ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
  return new AMDGPULowerKernelAttributes();
}

PreservedAnalyses
AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) {
  bool IsV5OrAbove =
      AMDGPU::getAMDHSACodeObjectVersion(*F.getParent()) >= AMDGPU::AMDHSA_COV5;
  Function *BasePtr = getBasePtrIntrinsic(*F.getParent(), IsV5OrAbove);

  if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
    return PreservedAnalyses::all();

  for (Instruction &I : instructions(F)) {
    if (CallInst *CI = dyn_cast<CallInst>(&I)) {
      if (CI->getCalledFunction() == BasePtr)
        processUse(CI, IsV5OrAbove);
    }
  }

  return PreservedAnalyses::all();
}
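
// Example invocation (assuming the pass is registered in the AMDGPU pass
// registry under the same name as DEBUG_TYPE above):
//   opt -mtriple=amdgcn-amd-amdhsa -passes=amdgpu-lower-kernel-attributes \
//       -S kernel.ll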