LLVM 20.0.0git
DXILOpBuilder.cpp
Go to the documentation of this file.
//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
/// \file This file contains a class to help build DXIL op functions.
10//===----------------------------------------------------------------------===//
11
#include "DXILOpBuilder.h"
#include "DXILConstants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/DXILABI.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
19
20using namespace llvm;
21using namespace llvm::dxil;
22
23constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
24
namespace {
/// Bit flags naming the types an operation may be overloaded on. Every kind
/// owns a distinct bit so a set of valid overloads can be stored as a mask
/// (see OpOverload::ValidTys).
enum OverloadKind : uint16_t {
  UNDEFINED = 0x0000,
  VOID = 0x0001,
  HALF = 0x0002,
  FLOAT = 0x0004,
  DOUBLE = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  UserDefineType = 0x0200,
  ObjectType = 0x0400,
};

/// A DXIL version number; both components default to zero.
struct Version {
  unsigned Major{0};
  unsigned Minor{0};
};

/// Mask of OverloadKind bits valid for an operation, starting at DXILVersion.
struct OpOverload {
  Version DXILVersion;
  uint16_t ValidTys;
};
} // namespace
50
51struct OpStage {
52 Version DXILVersion;
54};
55
57 Version DXILVersion;
59};
60
61static const char *getOverloadTypeName(OverloadKind Kind) {
62 switch (Kind) {
63 case OverloadKind::HALF:
64 return "f16";
65 case OverloadKind::FLOAT:
66 return "f32";
67 case OverloadKind::DOUBLE:
68 return "f64";
69 case OverloadKind::I1:
70 return "i1";
71 case OverloadKind::I8:
72 return "i8";
73 case OverloadKind::I16:
74 return "i16";
75 case OverloadKind::I32:
76 return "i32";
77 case OverloadKind::I64:
78 return "i64";
79 case OverloadKind::VOID:
80 case OverloadKind::UNDEFINED:
81 return "void";
82 case OverloadKind::ObjectType:
83 case OverloadKind::UserDefineType:
84 break;
85 }
86 llvm_unreachable("invalid overload type for name");
87}
88
89static OverloadKind getOverloadKind(Type *Ty) {
90 Type::TypeID T = Ty->getTypeID();
91 switch (T) {
92 case Type::VoidTyID:
93 return OverloadKind::VOID;
94 case Type::HalfTyID:
95 return OverloadKind::HALF;
96 case Type::FloatTyID:
97 return OverloadKind::FLOAT;
99 return OverloadKind::DOUBLE;
100 case Type::IntegerTyID: {
101 IntegerType *ITy = cast<IntegerType>(Ty);
102 unsigned Bits = ITy->getBitWidth();
103 switch (Bits) {
104 case 1:
105 return OverloadKind::I1;
106 case 8:
107 return OverloadKind::I8;
108 case 16:
109 return OverloadKind::I16;
110 case 32:
111 return OverloadKind::I32;
112 case 64:
113 return OverloadKind::I64;
114 default:
115 llvm_unreachable("invalid overload type");
116 return OverloadKind::VOID;
117 }
118 }
120 return OverloadKind::UserDefineType;
121 case Type::StructTyID:
122 return OverloadKind::ObjectType;
123 default:
124 llvm_unreachable("invalid overload type");
125 return OverloadKind::VOID;
126 }
127}
128
129static std::string getTypeName(OverloadKind Kind, Type *Ty) {
130 if (Kind < OverloadKind::UserDefineType) {
131 return getOverloadTypeName(Kind);
132 } else if (Kind == OverloadKind::UserDefineType) {
133 StructType *ST = cast<StructType>(Ty);
134 return ST->getStructName().str();
135 } else if (Kind == OverloadKind::ObjectType) {
136 StructType *ST = cast<StructType>(Ty);
137 return ST->getStructName().str();
138 } else {
139 std::string Str;
141 Ty->print(OS);
142 return OS.str();
143 }
144}
145
146// Static per-opcode properties, looked up through getOpCodeProperty().
// NOTE(review): this listing drops the struct header and several members
// (OpCode, OpCodeNameOffset, OpCodeClass, OpCodeClassNameOffset, Overloads,
// Stages, Attributes). If, as appears likely, the generated table in
// DXILOperation.inc initializes this struct positionally, member order is
// significant; restore it from upstream rather than by hand.
149 // Offset in DXILOpCodeNameTable.
152 // Offset in DXILOpCodeClassNameTable.
157 int OverloadParamIndex; // Parameter index that controls the overload (0 = return value).
158 // When < 0, the op is not overloaded on any parameter.
159 unsigned NumOfParameters; // Number of parameters, including the return value.
160 unsigned ParameterTableOffset; // Offset in ParameterTable.
161};
162
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
165#define DXIL_OP_OPERATION_TABLE
166#include "DXILOperation.inc"
167#undef DXIL_OP_OPERATION_TABLE
168
169static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
170 const OpCodeProperty &Prop) {
171 if (Kind == OverloadKind::VOID) {
172 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
173 }
174 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
175 getTypeName(Kind, Ty))
176 .str();
177}
178
179static std::string constructOverloadTypeName(OverloadKind Kind,
180 StringRef TypeName) {
181 if (Kind == OverloadKind::VOID)
182 return TypeName.str();
183
184 assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
185 return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
186}
187
189 ArrayRef<Type *> EltTys,
190 LLVMContext &Ctx) {
192 if (ST)
193 return ST;
194
195 return StructType::create(Ctx, EltTys, Name);
196}
197
198static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
199 OverloadKind Kind = getOverloadKind(OverloadTy);
200 std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
201 Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
202 Type::getInt32Ty(Ctx)};
203 return getOrCreateStructType(TypeName, FieldTypes, Ctx);
204}
205
207 return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx),
208 Ctx);
209}
210
212 Type *OverloadTy) {
213 switch (Kind) {
214 case ParameterKind::Void:
215 return Type::getVoidTy(Ctx);
216 case ParameterKind::Half:
217 return Type::getHalfTy(Ctx);
218 case ParameterKind::Float:
219 return Type::getFloatTy(Ctx);
220 case ParameterKind::Double:
221 return Type::getDoubleTy(Ctx);
222 case ParameterKind::I1:
223 return Type::getInt1Ty(Ctx);
224 case ParameterKind::I8:
225 return Type::getInt8Ty(Ctx);
226 case ParameterKind::I16:
227 return Type::getInt16Ty(Ctx);
228 case ParameterKind::I32:
229 return Type::getInt32Ty(Ctx);
230 case ParameterKind::I64:
231 return Type::getInt64Ty(Ctx);
232 case ParameterKind::Overload:
233 return OverloadTy;
234 case ParameterKind::ResourceRet:
235 return getResRetType(OverloadTy, Ctx);
236 case ParameterKind::DXILHandle:
237 return getHandleType(Ctx);
238 default:
239 break;
240 }
241 llvm_unreachable("Invalid parameter kind");
242 return nullptr;
243}
244
245static ShaderKind getShaderKindEnum(Triple::EnvironmentType EnvType) {
246 switch (EnvType) {
247 case Triple::Pixel:
248 return ShaderKind::pixel;
249 case Triple::Vertex:
250 return ShaderKind::vertex;
251 case Triple::Geometry:
252 return ShaderKind::geometry;
253 case Triple::Hull:
254 return ShaderKind::hull;
255 case Triple::Domain:
256 return ShaderKind::domain;
257 case Triple::Compute:
258 return ShaderKind::compute;
259 case Triple::Library:
260 return ShaderKind::library;
262 return ShaderKind::raygeneration;
264 return ShaderKind::intersection;
265 case Triple::AnyHit:
266 return ShaderKind::anyhit;
268 return ShaderKind::closesthit;
269 case Triple::Miss:
270 return ShaderKind::miss;
271 case Triple::Callable:
272 return ShaderKind::callable;
273 case Triple::Mesh:
274 return ShaderKind::mesh;
276 return ShaderKind::amplification;
277 default:
278 break;
279 }
281 "Shader Kind Not Found - Invalid DXIL Environment Specified");
282}
283
284/// Construct DXIL function type. This is the type of a function with
285/// the following prototype
286/// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)
287/// <param-types> are constructed from types in Prop.
289 LLVMContext &Context,
290 Type *OverloadTy) {
291 SmallVector<Type *> ArgTys;
292
293 const ParameterKind *ParamKinds = getOpCodeParameterKind(*Prop);
294
295 assert(Prop->NumOfParameters && "No return type?");
296 // Add return type of the function
297 Type *ReturnTy = getTypeFromParameterKind(ParamKinds[0], Context, OverloadTy);
298
299 // Add DXIL Opcode value type viz., Int32 as first argument
300 ArgTys.emplace_back(Type::getInt32Ty(Context));
301
302 // Add DXIL Operation parameter types as specified in DXIL properties
303 for (unsigned I = 1; I < Prop->NumOfParameters; ++I) {
304 ParameterKind Kind = ParamKinds[I];
305 ArgTys.emplace_back(getTypeFromParameterKind(Kind, Context, OverloadTy));
306 }
307 return FunctionType::get(ReturnTy, ArgTys, /*isVarArg=*/false);
308}
309
310/// Get index of the property from PropList valid for the most recent
311/// DXIL version not greater than DXILVer.
312/// PropList is expected to be sorted in ascending order of DXIL version.
313template <typename T>
314static std::optional<size_t> getPropIndex(ArrayRef<T> PropList,
315 const VersionTuple DXILVer) {
316 size_t Index = PropList.size() - 1;
317 for (auto Iter = PropList.rbegin(); Iter != PropList.rend();
318 Iter++, Index--) {
319 const T &Prop = *Iter;
320 if (VersionTuple(Prop.DXILVersion.Major, Prop.DXILVersion.Minor) <=
321 DXILVer) {
322 return Index;
323 }
324 }
325 return std::nullopt;
326}
327
328namespace llvm {
329namespace dxil {
330
331// No extra checks on TargetTriple need to be performed to verify that the
332// Triple is well-formed or that the target is supported since these checks
333// would have been done at the time the module M is constructed in the earlier
334// stages of compilation.
// NOTE(review): the constructor's signature line (declared as
// DXILOpBuilder(Module &M, IRBuilderBase &B)) is missing from this listing;
// restore it from upstream.
336 Triple TT(Triple(M.getTargetTriple()));
// Cache the DXIL version and shader stage from the triple; tryCreateOp()
// checks every operation against both.
337 DXILVersion = TT.getDXILVersion();
338 ShaderStage = TT.getEnvironment();
339 // Ensure Environment type is known: getShaderKindEnum() cannot map it.
340 if (ShaderStage == Triple::UnknownEnvironment) {
// NOTE(review): the opening of the fatal-error call that consumes the
// message below is missing from this listing (presumably
// report_fatal_error); confirm the exact message against upstream.
342 Twine(DXILVersion.getAsString()) +
343 ": Unknown Compilation Target Shader Stage specified ",
344 /*gen_crash_diag*/ false);
345 }
346}
347
349 return make_error<StringError>(
350 Twine("Cannot create ") + getOpCodeName(OpCode) + " operation: " + Msg,
352}
353
356 Type *RetTy) {
357 const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
358
359 Type *OverloadTy = nullptr;
360 if (Prop->OverloadParamIndex == 0) {
361 if (!RetTy)
362 return makeOpError(OpCode, "Op overloaded on unknown return type");
363 OverloadTy = RetTy;
364 } else if (Prop->OverloadParamIndex > 0) {
365 // The index counts including the return type
366 unsigned ArgIndex = Prop->OverloadParamIndex - 1;
367 if (static_cast<unsigned>(ArgIndex) >= Args.size())
368 return makeOpError(OpCode, "Wrong number of arguments");
369 OverloadTy = Args[ArgIndex]->getType();
370 }
371 FunctionType *DXILOpFT =
372 getDXILOpFunctionType(Prop, M.getContext(), OverloadTy);
373
374 std::optional<size_t> OlIndexOrErr =
375 getPropIndex(ArrayRef(Prop->Overloads), DXILVersion);
376 if (!OlIndexOrErr.has_value())
377 return makeOpError(OpCode, Twine("No valid overloads for DXIL version ") +
378 DXILVersion.getAsString());
379
380 uint16_t ValidTyMask = Prop->Overloads[*OlIndexOrErr].ValidTys;
381
382 // If we don't have an overload type, use the function's return type. This is
383 // a bit of a hack, but it's necessary to get the type suffix on unoverloaded
384 // DXIL ops correct, like `dx.op.threadId.i32`.
385 OverloadKind Kind =
386 getOverloadKind(OverloadTy ? OverloadTy : DXILOpFT->getReturnType());
387
388 // Check if the operation supports overload types and OverloadTy is valid
389 // per the specified types for the operation
390 if ((ValidTyMask != OverloadKind::UNDEFINED) &&
391 (ValidTyMask & (uint16_t)Kind) == 0)
392 return makeOpError(OpCode, "Invalid overload type");
393
394 // Perform necessary checks to ensure Opcode is valid in the targeted shader
395 // kind
396 std::optional<size_t> StIndexOrErr =
397 getPropIndex(ArrayRef(Prop->Stages), DXILVersion);
398 if (!StIndexOrErr.has_value())
399 return makeOpError(OpCode, Twine("No valid stage for DXIL version ") +
400 DXILVersion.getAsString());
401
402 uint16_t ValidShaderKindMask = Prop->Stages[*StIndexOrErr].ValidStages;
403
404 // Ensure valid shader stage properties are specified
405 if (ValidShaderKindMask == ShaderKind::removed)
406 return makeOpError(OpCode, "Operation has been removed");
407
408 // Shader stage need not be validated since getShaderKindEnum() fails
409 // for unknown shader stage.
410
411 // Verify the target shader stage is valid for the DXIL operation
412 ShaderKind ModuleStagekind = getShaderKindEnum(ShaderStage);
413 if (!(ValidShaderKindMask & ModuleStagekind))
414 return makeOpError(OpCode, "Invalid stage");
415
416 std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
417 FunctionCallee DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
418
419 // We need to inject the opcode as the first argument.
422 OpArgs.append(Args.begin(), Args.end());
423
424 return B.CreateCall(DXILFn, OpArgs);
425}
426
428 Type *RetTy) {
430 if (Error E = Result.takeError())
431 llvm_unreachable("Invalid arguments for operation");
432 return *Result;
433}
434
436 return ::getOpCodeName(DXILOp);
437}
438} // namespace dxil
439} // namespace llvm
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static ShaderKind getShaderKindEnum(Triple::EnvironmentType EnvType)
static std::optional< size_t > getPropIndex(ArrayRef< T > PropList, const VersionTuple DXILVer)
Get index of the property from PropList valid for the most recent DXIL version not greater than DXILV...
static StructType * getResRetType(Type *OverloadTy, LLVMContext &Ctx)
static const char * getOverloadTypeName(OverloadKind Kind)
static OverloadKind getOverloadKind(Type *Ty)
static FunctionType * getDXILOpFunctionType(const OpCodeProperty *Prop, LLVMContext &Context, Type *OverloadTy)
Construct DXIL function type.
static Type * getTypeFromParameterKind(ParameterKind Kind, LLVMContext &Ctx, Type *OverloadTy)
static StructType * getOrCreateStructType(StringRef Name, ArrayRef< Type * > EltTys, LLVMContext &Ctx)
static StructType * getHandleType(LLVMContext &Ctx)
static std::string constructOverloadName(OverloadKind Kind, Type *Ty, const OpCodeProperty &Prop)
constexpr StringLiteral DXILOpNamePrefix
static std::string constructOverloadTypeName(OverloadKind Kind, StringRef TypeName)
return RetTy
std::string Name
#define I(x, y, z)
Definition: MD5.cpp:58
Module.h This file contains the declarations for the Module class.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
raw_pwrite_stream & OS
@ UNDEFINED
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
reverse_iterator rend() const
Definition: ArrayRef.h:157
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:165
reverse_iterator rbegin() const
Definition: ArrayRef.h:156
This class represents a function call, abstracting a target machine's calling convention.
Lightweight error class with error context and mandatory checking.
Definition: Error.h:160
Tagged union holding either a T or a Error.
Definition: Error.h:481
A handy container for a FunctionType+Callee-pointer pair, which can be passed around as a single enti...
Definition: DerivedTypes.h:168
Class to represent function types.
Definition: DerivedTypes.h:103
Type * getReturnType() const
Definition: DerivedTypes.h:124
Common base class shared among various IRBuilders.
Definition: IRBuilder.h:91
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
Definition: IRBuilder.h:483
CallInst * CreateCall(FunctionType *FTy, Value *Callee, ArrayRef< Value * > Args=std::nullopt, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition: IRBuilder.h:2432
Class to represent integer types.
Definition: DerivedTypes.h:40
unsigned getBitWidth() const
Get the number of bits in this IntegerType.
Definition: DerivedTypes.h:72
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:299
FunctionCallee getOrInsertFunction(StringRef Name, FunctionType *T, AttributeList AttributeList)
Look up the specified function in the module symbol table.
Definition: Module.cpp:169
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:951
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:697
void push_back(const T &Elt)
Definition: SmallVector.h:427
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1210
A wrapper around a string literal that serves as a proxy for constructing global tables of StringRefs...
Definition: StringRef.h:838
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
Class to represent struct types.
Definition: DerivedTypes.h:216
static StructType * getTypeByName(LLVMContext &C, StringRef Name)
Return the type with the specified name, or null if there is none by that name.
Definition: Type.cpp:620
static StructType * create(LLVMContext &Context, StringRef Name)
This creates an identified struct.
Definition: Type.cpp:501
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
@ RayGeneration
Definition: Triple.h:282
@ UnknownEnvironment
Definition: Triple.h:243
@ ClosestHit
Definition: Triple.h:285
@ Amplification
Definition: Triple.h:289
@ Intersection
Definition: Triple.h:283
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static Type * getHalfTy(LLVMContext &C)
static Type * getDoubleTy(LLVMContext &C)
static IntegerType * getInt1Ty(LLVMContext &C)
TypeID
Definitions of all of the base types for the Type system.
Definition: Type.h:54
@ HalfTyID
16-bit floating point type
Definition: Type.h:56
@ VoidTyID
type with no size
Definition: Type.h:63
@ FloatTyID
32-bit floating point type
Definition: Type.h:58
@ StructTyID
Structures.
Definition: Type.h:73
@ IntegerTyID
Arbitrary bit width integers.
Definition: Type.h:70
@ DoubleTyID
64-bit floating point type
Definition: Type.h:59
@ PointerTyID
Pointers.
Definition: Type.h:72
void print(raw_ostream &O, bool IsForDebug=false, bool NoDetails=false) const
Print the current type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt16Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static IntegerType * getInt64Ty(LLVMContext &C)
static Type * getFloatTy(LLVMContext &C)
TypeID getTypeID() const
Return the type id for the type.
Definition: Type.h:136
Represents a version number in the form major[.minor[.subminor[.build]]].
Definition: VersionTuple.h:29
std::string getAsString() const
Retrieve a string representation of the version number.
DXILOpBuilder(Module &M, IRBuilderBase &B)
CallInst * createOp(dxil::OpCode Op, ArrayRef< Value * > &Args, Type *RetTy=nullptr)
Create a call instruction for the given DXIL op.
Expected< CallInst * > tryCreateOp(dxil::OpCode Op, ArrayRef< Value * > Args, Type *RetTy=nullptr)
Try to create a call instruction for the given DXIL op.
static const char * getOpCodeName(dxil::OpCode DXILOp)
Return the name of the given opcode.
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:661
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
static Error makeOpError(dxil::OpCode OpCode, Twine Msg)
ParameterKind
Definition: DXILABI.h:25
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
std::error_code inconvertibleErrorCode()
The value returned by this function can be returned from convertToErrorCode for Error values where no...
Definition: Error.cpp:98
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:167
constexpr std::underlying_type_t< Enum > to_underlying(Enum E)
Returns underlying integer value of an enum.
StringRef getTypeName()
We provide a function which tries to compute the (demangled) name of a type statically.
Definition: TypeName.h:27
Version DXILVersion
uint32_t ValidAttrs
llvm::SmallVector< OpOverload > Overloads
llvm::SmallVector< OpAttribute > Attributes
dxil::OpCodeClass OpCodeClass
unsigned OpCodeNameOffset
unsigned OpCodeClassNameOffset
unsigned ParameterTableOffset
unsigned NumOfParameters
llvm::SmallVector< OpStage > Stages
dxil::OpCode OpCode
uint32_t ValidStages
Version DXILVersion