LLVM 22.0.0git
DXContainerGlobals.cpp
Go to the documentation of this file.
1//===- DXContainerGlobals.cpp - DXContainer global generator pass ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// DXContainerGlobalsPass implementation.
10//
11//===----------------------------------------------------------------------===//
12
13#include "DXILRootSignature.h"
14#include "DXILShaderFlags.h"
15#include "DirectX.h"
18#include "llvm/ADT/StringRef.h"
22#include "llvm/CodeGen/Passes.h"
23#include "llvm/IR/Constants.h"
24#include "llvm/IR/Module.h"
27#include "llvm/Pass.h"
28#include "llvm/Support/MD5.h"
31#include <cstdint>
32#include <optional>
33
34using namespace llvm;
35using namespace llvm::dxil;
36using namespace llvm::mcdxbc;
37
38namespace {
// DXContainerGlobals is a ModulePass that materializes DXContainer part
// globals: SFI0 (feature flags), HASH (shader hash), ISG1/OSG1 (input/output
// signatures), RTS0 (root signature, when one is found) and PSV0 (pipeline
// state validation info). Each part is a private constant global placed in a
// section named after the part.
39class DXContainerGlobals : public llvm::ModulePass {
40
 // Wraps Content in a private constant global named Name in section
 // SectionName.
41 GlobalVariable *buildContainerGlobal(Module &M, Constant *Content,
42 StringRef Name, StringRef SectionName);
 // Builds the SFI0 part from the combined shader feature flags.
43 GlobalVariable *getFeatureFlags(Module &M);
 // Builds the HASH part from an MD5 digest of the "dx.dxil" global.
44 GlobalVariable *computeShaderHash(Module &M);
 // Serializes Sig and wraps the resulting bytes in a part global.
45 GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name,
46 StringRef SectionName);
 // Appends the ISG1 and OSG1 signature parts to Globals.
47 void addSignature(Module &M, SmallVector<GlobalValue *> &Globals);
 // Appends the RTS0 part to Globals when a root signature is available.
48 void addRootSignature(Module &M, SmallVector<GlobalValue *> &Globals);
 // Appends resource binding records to PSV.Resources.
49 void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV);
 // Appends the PSV0 part to Globals.
50 void addPipelineStateValidationInfo(Module &M,
52
53public:
54 static char ID; // Pass identification, replacement for typeid
55 DXContainerGlobals() : ModulePass(ID) {}
56
57 StringRef getPassName() const override {
58 return "DXContainer Global Emitter";
59 }
60
61 bool runOnModule(Module &M) override;
62
 // Reports every analysis as preserved and requires each analysis that this
 // pass queries through getAnalysis<>().
63 void getAnalysisUsage(AnalysisUsage &AU) const override {
64 AU.setPreservesAll();
65 AU.addRequired<ShaderFlagsAnalysisWrapper>();
66 AU.addRequired<RootSignatureAnalysisWrapper>();
67 AU.addRequired<DXILMetadataAnalysisWrapperPass>();
68 AU.addRequired<DXILResourceTypeWrapperPass>();
69 AU.addRequired<DXILResourceWrapperPass>();
70 }
71};
72
73} // namespace
74
// Pass entry point. Creates each DXContainer part global in turn (SFI0, HASH,
// ISG1/OSG1, RTS0 when applicable, PSV0), collects them in Globals, and then
// adds all of them to llvm.compiler.used. Always reports the module as
// modified.
75bool DXContainerGlobals::runOnModule(Module &M) {
77 Globals.push_back(getFeatureFlags(M));
78 Globals.push_back(computeShaderHash(M));
79 addSignature(M, Globals);
80 addRootSignature(M, Globals);
81 addPipelineStateValidationInfo(M, Globals);
 // The part globals are private and nothing references them. Listing them in
 // llvm.compiler.used presumably keeps later passes from dropping them.
82 appendToCompilerUsed(M, Globals);
83 return true;
84}
85
86GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
87 uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>()
88 .getShaderFlags()
89 .getCombinedFlags()
90 .getFeatureFlags();
91
92 Constant *FeatureFlagsConstant =
93 ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags));
94 return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
95}
96
// Build the "HASH" part. The part holds an MD5 digest of the raw bytes of the
// "dx.dxil" global's initializer plus a flags field.
// NOTE(review): assumes a "dx.dxil" global exists with a ConstantDataArray
// initializer (presumably created by an earlier pass). If getNamedGlobal()
// returns null, it is dereferenced here.
97GlobalVariable *DXContainerGlobals::computeShaderHash(Module &M) {
98 auto *DXILConstant =
99 cast<ConstantDataArray>(M.getNamedGlobal("dx.dxil")->getInitializer());
 // Hash the serialized DXIL bytes.
100 MD5 Digest;
101 Digest.update(DXILConstant->getRawDataValues());
102 MD5::MD5Result Result = Digest.final();
103
104 dxbc::ShaderHash HashData = {0, {0}};
105 // The Hash's IncludesSource flag gets set whenever the hashed shader includes
106 // debug information.
107 if (!M.debug_compile_units().empty())
108 HashData.Flags = static_cast<uint32_t>(dxbc::HashFlags::IncludesSource);
109
 // Copy the 16-byte MD5 result into the digest field.
110 memcpy(reinterpret_cast<void *>(&HashData.Digest), Result.data(), 16);
 // NOTE(review): swapBytes() presumably converts the fields to the
 // container's byte order. It is expected to run only on big-endian hosts
 // (sys::IsBigEndianHost); confirm that guard is present.
112 HashData.swapBytes();
 // View HashData as raw bytes so it can become the part's payload.
113 StringRef Data(reinterpret_cast<char *>(&HashData), sizeof(dxbc::ShaderHash));
114
 // Presumably a byte-array constant built from Data, emitted as "dx.hash" in
 // section HASH.
115 Constant *ModuleConstant =
117 return buildContainerGlobal(M, ModuleConstant, "dx.hash", "HASH");
118}
119
120GlobalVariable *DXContainerGlobals::buildContainerGlobal(
121 Module &M, Constant *Content, StringRef Name, StringRef SectionName) {
122 auto *GV = new llvm::GlobalVariable(
123 M, Content->getType(), true, GlobalValue::PrivateLinkage, Content, Name);
124 GV->setSection(SectionName);
125 GV->setAlignment(Align(4));
126 return GV;
127}
128
// Serialize Sig into an in-memory buffer and wrap those bytes (no trailing
// NUL, see AddNull=false) in a part global named Name in section SectionName.
129GlobalVariable *DXContainerGlobals::buildSignature(Module &M, Signature &Sig,
130 StringRef Name,
131 StringRef SectionName) {
132 SmallString<256> Data;
133 raw_svector_ostream OS(Data);
134 Sig.write(OS);
136 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
137 return buildContainerGlobal(M, Constant, Name, SectionName);
138}
139
// Append the input ("ISG1") and output ("OSG1") signature parts to Globals.
// Both signatures are default-constructed and no elements are added to them
// before serialization.
140void DXContainerGlobals::addSignature(Module &M,
142 // FIXME: support graphics shaders; the signatures below are always empty.
143 // See issue https://github.com/llvm/llvm-project/issues/90504.
144
145 Signature InputSig;
146 Globals.emplace_back(buildSignature(M, InputSig, "dx.isg1", "ISG1"));
147
148 Signature OutputSig;
149 Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1"));
150}
151
// Append the "RTS0" part (a serialized root signature) to Globals.
// Nothing is emitted for library shaders, or when the root signature analysis
// has no description for the selected entry function.
152void DXContainerGlobals::addRootSignature(Module &M,
154
155 dxil::ModuleMetadataInfo &MMI =
156 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
157
158 // Root signatures in libraries are not compiled into the DXContainer.
160 return;
161
162 auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
 // Function whose root signature is emitted. It remains nullptr unless it is
 // assigned from MMI.EntryPropertyVec below.
163 const Function *EntryFunction = nullptr;
164
166 assert(MMI.EntryPropertyVec.size() == 1);
167 EntryFunction = MMI.EntryPropertyVec[0].Entry;
168 }
169
170 const mcdxbc::RootSignatureDesc *RS = RSA.getDescForFunction(EntryFunction);
171 if (!RS)
172 return;
173
 // Serialize the root signature description into a byte buffer.
174 SmallString<256> Data;
175 raw_svector_ostream OS(Data);
176
177 RS->write(OS);
178
180 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
181 Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.rts0", "RTS0"));
182}
183
// Append one PSV resource-binding record to PSV.Resources for every CBV,
// sampler, SRV and UAV in the module's resource map, in that order.
// M is not used directly; all resource information comes from the DXIL
// resource analyses.
184void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) {
185 const DXILResourceMap &DRM =
186 getAnalysis<DXILResourceWrapperPass>().getResourceMap();
187 DXILResourceTypeMap &DRTM =
188 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
189
 // Convert a resource binding into a v2 PSV ResourceBindInfo record.
 // UpperBound is inclusive (LowerBound + Size - 1). A Size of UINT32_MAX
 // (presumably an unbounded range) yields UpperBound = UINT32_MAX.
190 auto MakeBinding =
191 [](const dxil::ResourceInfo::ResourceBinding &Binding,
193 const dxbc::PSV::ResourceFlags Flags = dxbc::PSV::ResourceFlags()) {
194 dxbc::PSV::v2::ResourceBindInfo BindInfo;
195 BindInfo.Type = Type;
196 BindInfo.LowerBound = Binding.LowerBound;
 // The range check is done in uint64_t so it cannot wrap at 32 bits.
 // NOTE(review): `&&` binds tighter than `||`, so the message string
 // attaches only to the range check.
197 assert(Binding.Size == UINT32_MAX ||
198 (uint64_t)Binding.LowerBound + Binding.Size - 1 <= UINT32_MAX &&
199 "Resource range is too large");
200 BindInfo.UpperBound = (Binding.Size == UINT32_MAX)
201 ? UINT32_MAX
202 : Binding.LowerBound + Binding.Size - 1;
203 BindInfo.Space = Binding.Space;
204 BindInfo.Kind = static_cast<dxbc::PSV::ResourceKind>(Kind);
205 BindInfo.Flags = Flags;
206 return BindInfo;
207 };
208
 // Constant buffers.
209 for (const dxil::ResourceInfo &RI : DRM.cbuffers()) {
210 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
211 PSV.Resources.push_back(MakeBinding(Binding, dxbc::PSV::ResourceType::CBV,
212 dxil::ResourceKind::CBuffer));
213 }
 // Samplers.
214 for (const dxil::ResourceInfo &RI : DRM.samplers()) {
215 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
216 PSV.Resources.push_back(MakeBinding(Binding,
217 dxbc::PSV::ResourceType::Sampler,
218 dxil::ResourceKind::Sampler));
219 }
 // SRVs: classified as structured, then typed, otherwise raw, based on the
 // type info of the resource's handle type.
220 for (const dxil::ResourceInfo &RI : DRM.srvs()) {
221 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
222
223 dxil::ResourceTypeInfo &TypeInfo = DRTM[RI.getHandleTy()];
225 if (TypeInfo.isStruct())
226 ResType = dxbc::PSV::ResourceType::SRVStructured;
227 else if (TypeInfo.isTyped())
228 ResType = dxbc::PSV::ResourceType::SRVTyped;
229 else
230 ResType = dxbc::PSV::ResourceType::SRVRaw;
231
232 PSV.Resources.push_back(
233 MakeBinding(Binding, ResType, TypeInfo.getResourceKind()));
234 }
 // UAVs: a UAV with a counter is classified as structured-with-counter
 // before the struct/typed/raw checks are applied.
235 for (const dxil::ResourceInfo &RI : DRM.uavs()) {
236 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding();
237
238 dxil::ResourceTypeInfo &TypeInfo = DRTM[RI.getHandleTy()];
240 if (RI.hasCounter())
241 ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter;
242 else if (TypeInfo.isStruct())
243 ResType = dxbc::PSV::ResourceType::UAVStructured;
244 else if (TypeInfo.isTyped())
245 ResType = dxbc::PSV::ResourceType::UAVTyped;
246 else
247 ResType = dxbc::PSV::ResourceType::UAVRaw;
248
249 dxbc::PSV::ResourceFlags Flags;
250 // TODO: Add support for dxbc::PSV::ResourceFlag::UsedByAtomic64, tracking
251 // with https://github.com/llvm/llvm-project/issues/104392
252 Flags.Flags = 0u;
253
254 PSV.Resources.push_back(
255 MakeBinding(Binding, ResType, TypeInfo.getResourceKind(), Flags));
256 }
257}
258
// Build the "PSV0" (pipeline state validation) part and append it to Globals.
// The PSVRuntimeInfo record is built in three steps:
//  - it starts from hard-coded base values;
//  - it is filled with resource bindings and, for compute shaders, the entry's
//    NumThreads X/Y/Z values;
//  - it is finalized for the module's shader profile and serialized.
259void DXContainerGlobals::addPipelineStateValidationInfo(
260 Module &M, SmallVector<GlobalValue *> &Globals) {
261 SmallString<256> Data;
262 raw_svector_ostream OS(Data);
263 PSVRuntimeInfo PSV;
 // The maximum wave lane count is set to UINT32_MAX, presumably meaning
 // "no constraint".
265 PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits<uint32_t>::max();
266
267 dxil::ModuleMetadataInfo &MMI =
268 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
269 assert(MMI.EntryPropertyVec.size() == 1 ||
273 static_cast<uint8_t>(MMI.ShaderProfile - Triple::Pixel);
274
 // Add one binding record per resource (see addResourcesForPSV).
275 addResourcesForPSV(M, PSV);
276
277 // Hardcoded values here to unblock loading the shader into D3D.
278 //
279 // TODO: Lots more stuff to do here!
280 //
281 // See issue https://github.com/llvm/llvm-project/issues/96674.
282 switch (MMI.ShaderProfile) {
283 case Triple::Compute:
284 PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
285 PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
286 PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
287 break;
288 default:
289 break;
290 }
291
 // Record the entry function's name. The first operand of the condition
 // excludes libraries.
292 if (MMI.ShaderProfile != Triple::Library &&
294 PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
295
296 PSV.finalize(MMI.ShaderProfile);
297 PSV.write(OS);
299 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
300 Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.psv0", "PSV0"));
301}
302
// Definition of the static pass-identification object declared in the class.
303char DXContainerGlobals::ID = 0;
// Register the pass under the name "dxil-globals". The last two macro
// arguments are `cfg` and `analysis` (see the INITIALIZE_PASS_BEGIN/END macro
// signatures).
// NOTE(review): `analysis` is passed as true, yet runOnModule adds globals and
// returns true. Confirm this is intended.
304INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals",
305 "DXContainer Global Emitter", false, true)
310INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals",
311 "DXContainer Global Emitter", false, true)
312
 // Allocate a new instance of the pass. Ownership presumably passes to the
 // caller (e.g. the pass manager that adds it).
314 return new DXContainerGlobals();
315}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file contains the declarations for the subclasses of Constant, which represent the different flavors of constant values that live in LLVM.
DXIL Resource Implicit Binding
Module.h This file contains the declarations for the Module class.
Machine Check Debug Module
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
This file defines the SmallVector class.
This file contains some functions that are useful when dealing with strings.
AnalysisUsage & addRequired()
void setPreservesAll()
Set by analyses that do not transform their input at all.
static LLVM_ABI Constant * getString(LLVMContext &Context, StringRef Initializer, bool AddNull=true)
This method constructs a CDS and initializes it with a text string.
static Constant * get(LLVMContext &Context, ArrayRef< ElementTy > Elts)
get() constructor - Return a constant with array type with an element count and element type matching the ArrayRef passed in.
Definition Constants.h:715
iterator_range< iterator > samplers()
iterator_range< iterator > srvs()
iterator_range< iterator > cbuffers()
iterator_range< iterator > uavs()
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition GlobalValue.h:61
LLVM_ABI void update(ArrayRef< uint8_t > Data)
Updates the hash for the byte stream provided.
Definition MD5.cpp:189
LLVM_ABI void final(MD5Result &Result)
Finishes off the hash and puts the result in result.
Definition MD5.cpp:234
ModulePass class - This class is used to implement unstructured interprocedural optimizations and analyses.
Definition Pass.h:255
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
reference emplace_back(ArgTypes &&... Args)
void push_back(const T &Elt)
@ RootSignature
Definition Triple.h:309
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVM_ABI bool isTyped() const
LLVM_ABI bool isStruct() const
dxil::ResourceKind getResourceKind() const
Wrapper pass for the legacy pass manager.
void write(raw_ostream &OS)
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
ResourceKind
The kind of resource for an SRV or UAV resource.
Definition DXILABI.h:36
constexpr bool IsBigEndianHost
This is an optimization pass for GlobalISel generic memory operations.
ArrayRef< CharT > arrayRefFromStringRef(StringRef Input)
Construct an array ref of unsigned chars from a string ref.
ModulePass * createDXContainerGlobalsPass()
Pass for generating DXContainer part globals.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference sizeof(SmallVector<T, 0>).
FunctionAddr VTableAddr uintptr_t uintptr_t Data
Definition InstrProf.h:189
LLVM_ABI void appendToCompilerUsed(Module &M, ArrayRef< GlobalValue * > Values)
Adds global values to the llvm.compiler.used list.
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:560
Triple::EnvironmentType ShaderProfile
SmallVector< EntryProperties > EntryPropertyVec
dxbc::PSV::v3::RuntimeInfo BaseData
SmallVector< dxbc::PSV::v2::ResourceBindInfo > Resources
void finalize(Triple::EnvironmentType Stage)
void write(raw_ostream &OS, uint32_t Version=std::numeric_limits< uint32_t >::max()) const