LLVM 19.0.0git
SPIRVLegalizerInfo.cpp
Go to the documentation of this file.
1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
22
23using namespace llvm;
24using namespace llvm::LegalizeActions;
25using namespace llvm::LegalityPredicates;
26
27static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28 TargetOpcode::G_ADD,
29 TargetOpcode::G_FADD,
30 TargetOpcode::G_SUB,
31 TargetOpcode::G_FSUB,
32 TargetOpcode::G_MUL,
33 TargetOpcode::G_FMUL,
34 TargetOpcode::G_SDIV,
35 TargetOpcode::G_UDIV,
36 TargetOpcode::G_FDIV,
37 TargetOpcode::G_SREM,
38 TargetOpcode::G_UREM,
39 TargetOpcode::G_FREM,
40 TargetOpcode::G_FNEG,
41 TargetOpcode::G_CONSTANT,
42 TargetOpcode::G_FCONSTANT,
43 TargetOpcode::G_AND,
44 TargetOpcode::G_OR,
45 TargetOpcode::G_XOR,
46 TargetOpcode::G_SHL,
47 TargetOpcode::G_ASHR,
48 TargetOpcode::G_LSHR,
49 TargetOpcode::G_SELECT,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT,
51};
52
53bool isTypeFoldingSupported(unsigned Opcode) {
54 return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55}
56
58 using namespace TargetOpcode;
59
60 this->ST = &ST;
61 GR = ST.getSPIRVGlobalRegistry();
62
63 const LLT s1 = LLT::scalar(1);
64 const LLT s8 = LLT::scalar(8);
65 const LLT s16 = LLT::scalar(16);
66 const LLT s32 = LLT::scalar(32);
67 const LLT s64 = LLT::scalar(64);
68
69 const LLT v16s64 = LLT::fixed_vector(16, 64);
70 const LLT v16s32 = LLT::fixed_vector(16, 32);
71 const LLT v16s16 = LLT::fixed_vector(16, 16);
72 const LLT v16s8 = LLT::fixed_vector(16, 8);
73 const LLT v16s1 = LLT::fixed_vector(16, 1);
74
75 const LLT v8s64 = LLT::fixed_vector(8, 64);
76 const LLT v8s32 = LLT::fixed_vector(8, 32);
77 const LLT v8s16 = LLT::fixed_vector(8, 16);
78 const LLT v8s8 = LLT::fixed_vector(8, 8);
79 const LLT v8s1 = LLT::fixed_vector(8, 1);
80
81 const LLT v4s64 = LLT::fixed_vector(4, 64);
82 const LLT v4s32 = LLT::fixed_vector(4, 32);
83 const LLT v4s16 = LLT::fixed_vector(4, 16);
84 const LLT v4s8 = LLT::fixed_vector(4, 8);
85 const LLT v4s1 = LLT::fixed_vector(4, 1);
86
87 const LLT v3s64 = LLT::fixed_vector(3, 64);
88 const LLT v3s32 = LLT::fixed_vector(3, 32);
89 const LLT v3s16 = LLT::fixed_vector(3, 16);
90 const LLT v3s8 = LLT::fixed_vector(3, 8);
91 const LLT v3s1 = LLT::fixed_vector(3, 1);
92
93 const LLT v2s64 = LLT::fixed_vector(2, 64);
94 const LLT v2s32 = LLT::fixed_vector(2, 32);
95 const LLT v2s16 = LLT::fixed_vector(2, 16);
96 const LLT v2s8 = LLT::fixed_vector(2, 8);
97 const LLT v2s1 = LLT::fixed_vector(2, 1);
98
99 const unsigned PSize = ST.getPointerSize();
100 const LLT p0 = LLT::pointer(0, PSize); // Function
101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104 const LLT p4 = LLT::pointer(4, PSize); // Generic
105 const LLT p5 =
106 LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
107 const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
108
109 // TODO: remove copy-pasting here by using concatenation in some way.
110 auto allPtrsScalarsAndVectors = {
111 p0, p1, p2, p3, p4, p5, p6, s1, s8, s16,
112 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
113 v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
114 v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
115
116 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
117 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
118 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
119 v16s8, v16s16, v16s32, v16s64};
120
121 auto allScalarsAndVectors = {
122 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
123 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
124 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
125
126 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
127 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
128 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
129 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
130
131 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
132
133 auto allIntScalars = {s8, s16, s32, s64};
134
135 auto allFloatScalars = {s16, s32, s64};
136
137 auto allFloatScalarsAndVectors = {
138 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
139 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
140
141 auto allFloatAndIntScalars = allIntScalars;
142
143 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
144 auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
145
146 for (auto Opc : TypeFoldingSupportingOpcs)
148
150
151 // TODO: add proper rules for vectors legalization.
153 {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
154 .alwaysLegal();
155
156 // Vector Reduction Operations
158 {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
159 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
160 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
161 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
162 .legalFor(allVectors)
163 .scalarize(1)
164 .lower();
165
166 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
167 .scalarize(2)
168 .lower();
169
170 // Merge/Unmerge
171 // TODO: add proper legalization rules.
172 getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
173
174 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
175 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
176
178 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
179
180 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
181 .legalForCartesianProduct(allPtrs, allPtrs);
182
183 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
184
185 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allIntScalarsAndVectors);
186
187 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
188
189 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
190 .legalForCartesianProduct(allIntScalarsAndVectors,
191 allFloatScalarsAndVectors);
192
193 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
194 .legalForCartesianProduct(allFloatScalarsAndVectors,
195 allScalarsAndVectors);
196
197 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
198 .legalFor(allIntScalarsAndVectors);
199
201 allIntScalarsAndVectors, allIntScalarsAndVectors);
202
203 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
204
206 all(typeInSet(0, allPtrsScalarsAndVectors),
207 typeInSet(1, allPtrsScalarsAndVectors)));
208
209 getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
210
211 getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
212
214 .legalForCartesianProduct(allPtrs, allIntScalars);
216 .legalForCartesianProduct(allIntScalars, allPtrs);
218 allPtrs, allIntScalars);
219
220 // ST.canDirectlyComparePointers() for pointer args is supported in
221 // legalizeCustom().
223 all(typeInSet(0, allBoolScalarsAndVectors),
224 typeInSet(1, allPtrsScalarsAndVectors)));
225
227 all(typeInSet(0, allBoolScalarsAndVectors),
228 typeInSet(1, allFloatScalarsAndVectors)));
229
230 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
231 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
232 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
233 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
234 .legalForCartesianProduct(allIntScalars, allWritablePtrs);
235
237 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
238 .legalForCartesianProduct(allFloatScalars, allWritablePtrs);
239
240 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
241 .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
242
243 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
244 // TODO: add proper legalization rules.
245 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
246
247 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
248 .alwaysLegal();
249
250 // Extensions.
251 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
252 .legalForCartesianProduct(allScalarsAndVectors);
253
254 // FP conversions.
255 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
256 .legalForCartesianProduct(allFloatScalarsAndVectors);
257
258 // Pointer-handling.
259 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
260
261 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
262 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
263
264 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
265 // tighten these requirements. Many of these math functions are only legal on
266 // specific bitwidths, so they are not selectable for
267 // allFloatScalarsAndVectors.
269 G_FEXP,
270 G_FEXP2,
271 G_FLOG,
272 G_FLOG2,
273 G_FLOG10,
274 G_FABS,
275 G_FMINNUM,
276 G_FMAXNUM,
277 G_FCEIL,
278 G_FCOS,
279 G_FSIN,
280 G_FTAN,
281 G_FSQRT,
282 G_FFLOOR,
283 G_FRINT,
284 G_FNEARBYINT,
285 G_INTRINSIC_ROUND,
286 G_INTRINSIC_TRUNC,
287 G_FMINIMUM,
288 G_FMAXIMUM,
289 G_INTRINSIC_ROUNDEVEN})
290 .legalFor(allFloatScalarsAndVectors);
291
292 getActionDefinitionsBuilder(G_FCOPYSIGN)
293 .legalForCartesianProduct(allFloatScalarsAndVectors,
294 allFloatScalarsAndVectors);
295
297 allFloatScalarsAndVectors, allIntScalarsAndVectors);
298
299 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
301 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
302 .legalForCartesianProduct(allIntScalarsAndVectors,
303 allIntScalarsAndVectors);
304
305 // Struct return types become a single scalar, so cannot easily legalize.
306 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
307
308 // supported saturation arithmetic
309 getActionDefinitionsBuilder({G_SADDSAT, G_UADDSAT, G_SSUBSAT, G_USUBSAT})
310 .legalFor(allIntScalarsAndVectors);
311 }
312
314 verify(*ST.getInstrInfo());
315}
316
317static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
318 LegalizerHelper &Helper,
321 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
322 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
323 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
324 .addDef(ConvReg)
325 .addUse(Reg);
326 return ConvReg;
327}
328
331 LostDebugLocObserver &LocObserver) const {
332 auto Opc = MI.getOpcode();
333 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
334 if (!isTypeFoldingSupported(Opc)) {
335 assert(Opc == TargetOpcode::G_ICMP);
336 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
337 auto &Op0 = MI.getOperand(2);
338 auto &Op1 = MI.getOperand(3);
339 Register Reg0 = Op0.getReg();
340 Register Reg1 = Op1.getReg();
342 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
343 if ((!ST->canDirectlyComparePointers() ||
345 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
346 LLT ConvT = LLT::scalar(ST->getPointerSize());
347 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
348 ST->getPointerSize());
349 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
350 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
351 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
352 }
353 return true;
354 }
355 // TODO: implement legalization for other opcodes.
356 return true;
357}
unsigned const MachineRegisterInfo * MRI
static void scalarize(BinaryOperator *BO, SmallVectorImpl< BinaryOperator * > &Replace)
IRTranslator LLVM IR MI
This file declares the MachineIRBuilder class.
ppc ctr loops verify
const SmallVectorImpl< MachineOperand > & Cond
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
bool isTypeFoldingSupported(unsigned Opcode)
static const std::set< unsigned > TypeFoldingSupportingOpcs
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
bool isTypeFoldingSupported(unsigned Opcode)
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:993
@ ICMP_EQ
equal
Definition: InstrTypes.h:1014
@ ICMP_NE
not equal
Definition: InstrTypes.h:1015
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:278
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Definition: LowLevelType.h:42
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
Definition: LowLevelType.h:57
static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
Definition: LowLevelType.h:100
void computeTables()
Compute any ancillary tables needed to quickly decide how an operation should be handled.
LegalizeRuleSet & legalFor(std::initializer_list< LLT > Types)
The instruction is legal when type index 0 is any type in the given list.
LegalizeRuleSet & lower()
The instruction is lowered.
LegalizeRuleSet & custom()
Unconditionally custom lower.
LegalizeRuleSet & alwaysLegal()
LegalizeRuleSet & customIf(LegalityPredicate Predicate)
LegalizeRuleSet & scalarize(unsigned TypeIdx)
LegalizeRuleSet & legalForCartesianProduct(std::initializer_list< LLT > Types)
The instruction is legal when type indexes 0 and 1 are both in the given list.
LegalizeRuleSet & legalIf(LegalityPredicate Predicate)
The instruction is legal if predicate is true.
MachineIRBuilder & MIRBuilder
Expose MIRBuilder so clients can set their own RecordInsertInstruction functions.
LegalizeRuleSet & getActionDefinitionsBuilder(unsigned Opcode)
Get the action definition builder for the given opcode.
const LegacyLegalizerInfo & getLegacyLegalizerInfo() const
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
Definition: MachineInstr.h:69
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ=SPIRV::AccessQualifier::ReadWrite, bool EmitIR=true)
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, MachineFunction &MF)
SPIRVLegalizerInfo(const SPIRVSubtarget &ST)
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override
Called for instructions with the Custom LegalizationAction.
unsigned getPointerSize() const
bool canDirectlyComparePointers() const
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
LegalityPredicate typeInSet(unsigned TypeIdx, std::initializer_list< LLT > TypesInit)
True iff the given type index is one of the specified types.
Predicate all(Predicate P0, Predicate P1)
True iff P0 and P1 are true.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18