LLVM 19.0.0git
SPIRVLegalizerInfo.cpp
Go to the documentation of this file.
1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
22
23using namespace llvm;
24using namespace llvm::LegalizeActions;
25using namespace llvm::LegalityPredicates;
26
27static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28 TargetOpcode::G_ADD,
29 TargetOpcode::G_FADD,
30 TargetOpcode::G_SUB,
31 TargetOpcode::G_FSUB,
32 TargetOpcode::G_MUL,
33 TargetOpcode::G_FMUL,
34 TargetOpcode::G_SDIV,
35 TargetOpcode::G_UDIV,
36 TargetOpcode::G_FDIV,
37 TargetOpcode::G_SREM,
38 TargetOpcode::G_UREM,
39 TargetOpcode::G_FREM,
40 TargetOpcode::G_FNEG,
41 TargetOpcode::G_CONSTANT,
42 TargetOpcode::G_FCONSTANT,
43 TargetOpcode::G_AND,
44 TargetOpcode::G_OR,
45 TargetOpcode::G_XOR,
46 TargetOpcode::G_SHL,
47 TargetOpcode::G_ASHR,
48 TargetOpcode::G_LSHR,
49 TargetOpcode::G_SELECT,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT,
51};
52
53bool isTypeFoldingSupported(unsigned Opcode) {
54 return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55}
56
58 using namespace TargetOpcode;
59
60 this->ST = &ST;
61 GR = ST.getSPIRVGlobalRegistry();
62
63 const LLT s1 = LLT::scalar(1);
64 const LLT s8 = LLT::scalar(8);
65 const LLT s16 = LLT::scalar(16);
66 const LLT s32 = LLT::scalar(32);
67 const LLT s64 = LLT::scalar(64);
68
69 const LLT v16s64 = LLT::fixed_vector(16, 64);
70 const LLT v16s32 = LLT::fixed_vector(16, 32);
71 const LLT v16s16 = LLT::fixed_vector(16, 16);
72 const LLT v16s8 = LLT::fixed_vector(16, 8);
73 const LLT v16s1 = LLT::fixed_vector(16, 1);
74
75 const LLT v8s64 = LLT::fixed_vector(8, 64);
76 const LLT v8s32 = LLT::fixed_vector(8, 32);
77 const LLT v8s16 = LLT::fixed_vector(8, 16);
78 const LLT v8s8 = LLT::fixed_vector(8, 8);
79 const LLT v8s1 = LLT::fixed_vector(8, 1);
80
81 const LLT v4s64 = LLT::fixed_vector(4, 64);
82 const LLT v4s32 = LLT::fixed_vector(4, 32);
83 const LLT v4s16 = LLT::fixed_vector(4, 16);
84 const LLT v4s8 = LLT::fixed_vector(4, 8);
85 const LLT v4s1 = LLT::fixed_vector(4, 1);
86
87 const LLT v3s64 = LLT::fixed_vector(3, 64);
88 const LLT v3s32 = LLT::fixed_vector(3, 32);
89 const LLT v3s16 = LLT::fixed_vector(3, 16);
90 const LLT v3s8 = LLT::fixed_vector(3, 8);
91 const LLT v3s1 = LLT::fixed_vector(3, 1);
92
93 const LLT v2s64 = LLT::fixed_vector(2, 64);
94 const LLT v2s32 = LLT::fixed_vector(2, 32);
95 const LLT v2s16 = LLT::fixed_vector(2, 16);
96 const LLT v2s8 = LLT::fixed_vector(2, 8);
97 const LLT v2s1 = LLT::fixed_vector(2, 1);
98
99 const unsigned PSize = ST.getPointerSize();
100 const LLT p0 = LLT::pointer(0, PSize); // Function
101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104 const LLT p4 = LLT::pointer(4, PSize); // Generic
105 const LLT p5 = LLT::pointer(5, PSize); // Input
106
107 // TODO: remove copy-pasting here by using concatenation in some way.
108 auto allPtrsScalarsAndVectors = {
109 p0, p1, p2, p3, p4, p5, s1, s8, s16,
110 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
111 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
112 v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113
114 auto allScalarsAndVectors = {
115 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
116 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
117 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118
119 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
120 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
121 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
122 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123
124 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125
126 auto allIntScalars = {s8, s16, s32, s64};
127
128 auto allFloatScalars = {s16, s32, s64};
129
130 auto allFloatScalarsAndVectors = {
131 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
132 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
133
134 auto allFloatAndIntScalars = allIntScalars;
135
136 auto allPtrs = {p0, p1, p2, p3, p4, p5};
137 auto allWritablePtrs = {p0, p1, p3, p4};
138
139 for (auto Opc : TypeFoldingSupportingOpcs)
141
143
144 // TODO: add proper rules for vectors legalization.
145 getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
146
147 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
148 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
149
151 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
152
153 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
154 .legalForCartesianProduct(allPtrs, allPtrs);
155
156 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
157
158 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
159
160 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
161
162 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
163 .legalForCartesianProduct(allIntScalarsAndVectors,
164 allFloatScalarsAndVectors);
165
166 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
167 .legalForCartesianProduct(allFloatScalarsAndVectors,
168 allScalarsAndVectors);
169
170 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
171 .legalFor(allIntScalarsAndVectors);
172
174 allIntScalarsAndVectors, allIntScalarsAndVectors);
175
176 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
177
179 typeInSet(0, allPtrsScalarsAndVectors),
180 typeInSet(1, allPtrsScalarsAndVectors),
181 LegalityPredicate(([=](const LegalityQuery &Query) {
182 return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
183 }))));
184
186
188 .legalForCartesianProduct(allPtrs, allIntScalars);
190 .legalForCartesianProduct(allIntScalars, allPtrs);
192 allPtrs, allIntScalars);
193
194 // ST.canDirectlyComparePointers() for pointer args is supported in
195 // legalizeCustom().
197 all(typeInSet(0, allBoolScalarsAndVectors),
198 typeInSet(1, allPtrsScalarsAndVectors)));
199
201 all(typeInSet(0, allBoolScalarsAndVectors),
202 typeInSet(1, allFloatScalarsAndVectors)));
203
204 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
205 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
206 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
207 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
208 .legalForCartesianProduct(allIntScalars, allWritablePtrs);
209
211 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
212 .legalForCartesianProduct(allFloatScalars, allWritablePtrs);
213
214 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
215 .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
216
217 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
218 // TODO: add proper legalization rules.
219 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
220
221 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
222 .alwaysLegal();
223
224 // Extensions.
225 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
226 .legalForCartesianProduct(allScalarsAndVectors);
227
228 // FP conversions.
229 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
230 .legalForCartesianProduct(allFloatScalarsAndVectors);
231
232 // Pointer-handling.
233 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
234
235 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
236 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
237
238 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
239 // tighten these requirements. Many of these math functions are only legal on
240 // specific bitwidths, so they are not selectable for
241 // allFloatScalarsAndVectors.
243 G_FEXP,
244 G_FEXP2,
245 G_FLOG,
246 G_FLOG2,
247 G_FLOG10,
248 G_FABS,
249 G_FMINNUM,
250 G_FMAXNUM,
251 G_FCEIL,
252 G_FCOS,
253 G_FSIN,
254 G_FSQRT,
255 G_FFLOOR,
256 G_FRINT,
257 G_FNEARBYINT,
258 G_INTRINSIC_ROUND,
259 G_INTRINSIC_TRUNC,
260 G_FMINIMUM,
261 G_FMAXIMUM,
262 G_INTRINSIC_ROUNDEVEN})
263 .legalFor(allFloatScalarsAndVectors);
264
265 getActionDefinitionsBuilder(G_FCOPYSIGN)
266 .legalForCartesianProduct(allFloatScalarsAndVectors,
267 allFloatScalarsAndVectors);
268
270 allFloatScalarsAndVectors, allIntScalarsAndVectors);
271
272 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
274 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
275 .legalForCartesianProduct(allIntScalarsAndVectors,
276 allIntScalarsAndVectors);
277
278 // Struct return types become a single scalar, so cannot easily legalize.
279 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
280 }
281
283 verify(*ST.getInstrInfo());
284}
285
286static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
287 LegalizerHelper &Helper,
290 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
291 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
292 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
293 .addDef(ConvReg)
294 .addUse(Reg);
295 return ConvReg;
296}
297
300 LostDebugLocObserver &LocObserver) const {
301 auto Opc = MI.getOpcode();
302 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
303 if (!isTypeFoldingSupported(Opc)) {
304 assert(Opc == TargetOpcode::G_ICMP);
305 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
306 auto &Op0 = MI.getOperand(2);
307 auto &Op1 = MI.getOperand(3);
308 Register Reg0 = Op0.getReg();
309 Register Reg1 = Op1.getReg();
311 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
312 if ((!ST->canDirectlyComparePointers() ||
314 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
315 LLT ConvT = LLT::scalar(ST->getPointerSize());
316 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
317 ST->getPointerSize());
318 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
319 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
320 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
321 }
322 return true;
323 }
324 // TODO: implement legalization for other opcodes.
325 return true;
326}
unsigned const MachineRegisterInfo * MRI
IRTranslator LLVM IR MI
This file declares the MachineIRBuilder class.
ppc ctr loops verify
const SmallVectorImpl< MachineOperand > & Cond
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
bool isTypeFoldingSupported(unsigned Opcode)
static const std::set< unsigned > TypeFoldingSupportingOpcs
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
bool isTypeFoldingSupported(unsigned Opcode)
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:780
@ ICMP_EQ
equal
Definition: InstrTypes.h:801
@ ICMP_NE
not equal
Definition: InstrTypes.h:802
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:278
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Definition: LowLevelType.h:42
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
Definition: LowLevelType.h:49
static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
Definition: LowLevelType.h:92
void computeTables()
Compute any ancillary tables needed to quickly decide how an operation should be handled.
LegalizeRuleSet & legalFor(std::initializer_list< LLT > Types)
The instruction is legal when type index 0 is any type in the given list.
LegalizeRuleSet & lower()
The instruction is lowered.
LegalizeRuleSet & custom()
Unconditionally custom lower.
LegalizeRuleSet & alwaysLegal()
LegalizeRuleSet & customIf(LegalityPredicate Predicate)
LegalizeRuleSet & legalForCartesianProduct(std::initializer_list< LLT > Types)
The instruction is legal when type indexes 0 and 1 are both in the given list.
LegalizeRuleSet & legalIf(LegalityPredicate Predicate)
The instruction is legal if predicate is true.
MachineIRBuilder & MIRBuilder
Expose MIRBuilder so clients can set their own RecordInsertInstruction functions.
LegalizeRuleSet & getActionDefinitionsBuilder(unsigned Opcode)
Get the action definition builder for the given opcode.
const LegacyLegalizerInfo & getLegacyLegalizerInfo() const
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
Definition: MachineInstr.h:68
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
SPIRVType * getSPIRVTypeForVReg(Register VReg) const
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ=SPIRV::AccessQualifier::ReadWrite, bool EmitIR=true)
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, MachineFunction &MF)
SPIRVLegalizerInfo(const SPIRVSubtarget &ST)
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override
Called for instructions with the Custom LegalizationAction.
unsigned getPointerSize() const
bool canDirectlyComparePointers() const
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
LegalityPredicate typeInSet(unsigned TypeIdx, std::initializer_list< LLT > TypesInit)
True iff the given type index is one of the specified types.
Predicate all(Predicate P0, Predicate P1)
True iff P0 and P1 are true.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
std::function< bool(const LegalityQuery &)> LegalityPredicate
The LegalityQuery object bundles together all the information that's needed to decide whether a given...
ArrayRef< LLT > Types