LLVM 20.0.0git
SPIRVLegalizerInfo.cpp
Go to the documentation of this file.
1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
22
23using namespace llvm;
24using namespace llvm::LegalizeActions;
25using namespace llvm::LegalityPredicates;
26
27static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28 TargetOpcode::G_ADD,
29 TargetOpcode::G_FADD,
30 TargetOpcode::G_SUB,
31 TargetOpcode::G_FSUB,
32 TargetOpcode::G_MUL,
33 TargetOpcode::G_FMUL,
34 TargetOpcode::G_SDIV,
35 TargetOpcode::G_UDIV,
36 TargetOpcode::G_FDIV,
37 TargetOpcode::G_SREM,
38 TargetOpcode::G_UREM,
39 TargetOpcode::G_FREM,
40 TargetOpcode::G_FNEG,
41 TargetOpcode::G_CONSTANT,
42 TargetOpcode::G_FCONSTANT,
43 TargetOpcode::G_AND,
44 TargetOpcode::G_OR,
45 TargetOpcode::G_XOR,
46 TargetOpcode::G_SHL,
47 TargetOpcode::G_ASHR,
48 TargetOpcode::G_LSHR,
49 TargetOpcode::G_SELECT,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT,
51};
52
53bool isTypeFoldingSupported(unsigned Opcode) {
54 return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55}
56
57LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
58 return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
59 const LLT Ty = Query.Types[TypeIdx];
60 return IsExtendedInts && Ty.isValid() && Ty.isScalar();
61 };
62}
63
65 using namespace TargetOpcode;
66
67 this->ST = &ST;
68 GR = ST.getSPIRVGlobalRegistry();
69
70 const LLT s1 = LLT::scalar(1);
71 const LLT s8 = LLT::scalar(8);
72 const LLT s16 = LLT::scalar(16);
73 const LLT s32 = LLT::scalar(32);
74 const LLT s64 = LLT::scalar(64);
75
76 const LLT v16s64 = LLT::fixed_vector(16, 64);
77 const LLT v16s32 = LLT::fixed_vector(16, 32);
78 const LLT v16s16 = LLT::fixed_vector(16, 16);
79 const LLT v16s8 = LLT::fixed_vector(16, 8);
80 const LLT v16s1 = LLT::fixed_vector(16, 1);
81
82 const LLT v8s64 = LLT::fixed_vector(8, 64);
83 const LLT v8s32 = LLT::fixed_vector(8, 32);
84 const LLT v8s16 = LLT::fixed_vector(8, 16);
85 const LLT v8s8 = LLT::fixed_vector(8, 8);
86 const LLT v8s1 = LLT::fixed_vector(8, 1);
87
88 const LLT v4s64 = LLT::fixed_vector(4, 64);
89 const LLT v4s32 = LLT::fixed_vector(4, 32);
90 const LLT v4s16 = LLT::fixed_vector(4, 16);
91 const LLT v4s8 = LLT::fixed_vector(4, 8);
92 const LLT v4s1 = LLT::fixed_vector(4, 1);
93
94 const LLT v3s64 = LLT::fixed_vector(3, 64);
95 const LLT v3s32 = LLT::fixed_vector(3, 32);
96 const LLT v3s16 = LLT::fixed_vector(3, 16);
97 const LLT v3s8 = LLT::fixed_vector(3, 8);
98 const LLT v3s1 = LLT::fixed_vector(3, 1);
99
100 const LLT v2s64 = LLT::fixed_vector(2, 64);
101 const LLT v2s32 = LLT::fixed_vector(2, 32);
102 const LLT v2s16 = LLT::fixed_vector(2, 16);
103 const LLT v2s8 = LLT::fixed_vector(2, 8);
104 const LLT v2s1 = LLT::fixed_vector(2, 1);
105
106 const unsigned PSize = ST.getPointerSize();
107 const LLT p0 = LLT::pointer(0, PSize); // Function
108 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
109 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
110 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
111 const LLT p4 = LLT::pointer(4, PSize); // Generic
112 const LLT p5 =
113 LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
114 const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
115
116 // TODO: remove copy-pasting here by using concatenation in some way.
117 auto allPtrsScalarsAndVectors = {
118 p0, p1, p2, p3, p4, p5, p6, s1, s8, s16,
119 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
120 v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
121 v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
122
123 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
124 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
125 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
126 v16s8, v16s16, v16s32, v16s64};
127
128 auto allScalarsAndVectors = {
129 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
130 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
131 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
132
133 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
134 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
135 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
136 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
137
138 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
139
140 auto allIntScalars = {s8, s16, s32, s64};
141
142 auto allFloatScalars = {s16, s32, s64};
143
144 auto allFloatScalarsAndVectors = {
145 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
146 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
147
148 auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,
149 p2, p3, p4, p5, p6};
150
151 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
152
153 bool IsExtendedInts =
154 ST.canUseExtension(
155 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
156 ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
157 auto extendedScalarsAndVectors =
158 [IsExtendedInts](const LegalityQuery &Query) {
159 const LLT Ty = Query.Types[0];
160 return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
161 };
162 auto extendedScalarsAndVectorsProduct = [IsExtendedInts](
163 const LegalityQuery &Query) {
164 const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1];
165 return IsExtendedInts && Ty1.isValid() && Ty2.isValid() &&
166 !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector();
167 };
168 auto extendedPtrsScalarsAndVectors =
169 [IsExtendedInts](const LegalityQuery &Query) {
170 const LLT Ty = Query.Types[0];
171 return IsExtendedInts && Ty.isValid();
172 };
173
174 for (auto Opc : TypeFoldingSupportingOpcs)
176
178
179 // TODO: add proper rules for vectors legalization.
181 {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
182 .alwaysLegal();
183
184 // Vector Reduction Operations
186 {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
187 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
188 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
189 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
190 .legalFor(allVectors)
191 .scalarize(1)
192 .lower();
193
194 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
195 .scalarize(2)
196 .lower();
197
198 // Merge/Unmerge
199 // TODO: add proper legalization rules.
200 getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
201
202 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
203 .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
204
206 all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars)));
207
208 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
209 .legalForCartesianProduct(allPtrs, allPtrs);
210
211 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
212
213 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
214 G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
215 G_USUBSAT})
216 .legalFor(allIntScalarsAndVectors)
217 .legalIf(extendedScalarsAndVectors);
218
219 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
220
221 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
222 .legalForCartesianProduct(allIntScalarsAndVectors,
223 allFloatScalarsAndVectors);
224
225 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
226 .legalForCartesianProduct(allFloatScalarsAndVectors,
227 allScalarsAndVectors);
228
230 .legalForCartesianProduct(allIntScalarsAndVectors)
231 .legalIf(extendedScalarsAndVectorsProduct);
232
233 // Extensions.
234 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
235 .legalForCartesianProduct(allScalarsAndVectors)
236 .legalIf(extendedScalarsAndVectorsProduct);
237
239 .legalFor(allPtrsScalarsAndVectors)
240 .legalIf(extendedPtrsScalarsAndVectors);
241
243 all(typeInSet(0, allPtrsScalarsAndVectors),
244 typeInSet(1, allPtrsScalarsAndVectors)));
245
246 getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
247
248 getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
249
251 .legalForCartesianProduct(allPtrs, allIntScalars)
252 .legalIf(
253 all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
255 .legalForCartesianProduct(allIntScalars, allPtrs)
256 .legalIf(
257 all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)));
259 .legalForCartesianProduct(allPtrs, allIntScalars)
260 .legalIf(
261 all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
262
263 // ST.canDirectlyComparePointers() for pointer args is supported in
264 // legalizeCustom().
266 all(typeInSet(0, allBoolScalarsAndVectors),
267 typeInSet(1, allPtrsScalarsAndVectors)));
268
270 all(typeInSet(0, allBoolScalarsAndVectors),
271 typeInSet(1, allFloatScalarsAndVectors)));
272
273 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
274 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
275 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
276 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
277 .legalForCartesianProduct(allIntScalars, allPtrs);
278
280 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
281 .legalForCartesianProduct(allFloatScalars, allPtrs);
282
283 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
284 .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);
285
286 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
287 // TODO: add proper legalization rules.
288 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
289
290 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
291 .alwaysLegal();
292
293 // FP conversions.
294 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
295 .legalForCartesianProduct(allFloatScalarsAndVectors);
296
297 // Pointer-handling.
298 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
299
300 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
301 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
302
303 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
304 // tighten these requirements. Many of these math functions are only legal on
305 // specific bitwidths, so they are not selectable for
306 // allFloatScalarsAndVectors.
308 G_FEXP,
309 G_FEXP2,
310 G_FLOG,
311 G_FLOG2,
312 G_FLOG10,
313 G_FABS,
314 G_FMINNUM,
315 G_FMAXNUM,
316 G_FCEIL,
317 G_FCOS,
318 G_FSIN,
319 G_FTAN,
320 G_FACOS,
321 G_FASIN,
322 G_FATAN,
323 G_FCOSH,
324 G_FSINH,
325 G_FTANH,
326 G_FSQRT,
327 G_FFLOOR,
328 G_FRINT,
329 G_FNEARBYINT,
330 G_INTRINSIC_ROUND,
331 G_INTRINSIC_TRUNC,
332 G_FMINIMUM,
333 G_FMAXIMUM,
334 G_INTRINSIC_ROUNDEVEN})
335 .legalFor(allFloatScalarsAndVectors);
336
337 getActionDefinitionsBuilder(G_FCOPYSIGN)
338 .legalForCartesianProduct(allFloatScalarsAndVectors,
339 allFloatScalarsAndVectors);
340
342 allFloatScalarsAndVectors, allIntScalarsAndVectors);
343
344 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
346 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
347 .legalForCartesianProduct(allIntScalarsAndVectors,
348 allIntScalarsAndVectors);
349
350 // Struct return types become a single scalar, so cannot easily legalize.
351 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
352 }
353
355 verify(*ST.getInstrInfo());
356}
357
358static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
359 LegalizerHelper &Helper,
362 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
363 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
364 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
365 .addDef(ConvReg)
366 .addUse(Reg);
367 return ConvReg;
368}
369
372 LostDebugLocObserver &LocObserver) const {
373 auto Opc = MI.getOpcode();
374 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
375 if (!isTypeFoldingSupported(Opc)) {
376 assert(Opc == TargetOpcode::G_ICMP);
377 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
378 auto &Op0 = MI.getOperand(2);
379 auto &Op1 = MI.getOperand(3);
380 Register Reg0 = Op0.getReg();
381 Register Reg1 = Op1.getReg();
383 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
384 if ((!ST->canDirectlyComparePointers() ||
386 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
387 LLT ConvT = LLT::scalar(ST->getPointerSize());
388 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
389 ST->getPointerSize());
390 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
391 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
392 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
393 }
394 return true;
395 }
396 // TODO: implement legalization for other opcodes.
397 return true;
398}
unsigned const MachineRegisterInfo * MRI
static void scalarize(BinaryOperator *BO, SmallVectorImpl< BinaryOperator * > &Replace)
IRTranslator LLVM IR MI
This file declares the MachineIRBuilder class.
ppc ctr loops verify
const SmallVectorImpl< MachineOperand > & Cond
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
bool isTypeFoldingSupported(unsigned Opcode)
static const std::set< unsigned > TypeFoldingSupportingOpcs
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
bool isTypeFoldingSupported(unsigned Opcode)
LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts)
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:757
@ ICMP_EQ
equal
Definition: InstrTypes.h:778
@ ICMP_NE
not equal
Definition: InstrTypes.h:779
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:266
constexpr bool isScalar() const
Definition: LowLevelType.h:146
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Definition: LowLevelType.h:42
constexpr bool isValid() const
Definition: LowLevelType.h:145
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
Definition: LowLevelType.h:57
static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
Definition: LowLevelType.h:100
constexpr bool isPointerOrPointerVector() const
Definition: LowLevelType.h:153
void computeTables()
Compute any ancillary tables needed to quickly decide how an operation should be handled.
LegalizeRuleSet & legalFor(std::initializer_list< LLT > Types)
The instruction is legal when type index 0 is any type in the given list.
LegalizeRuleSet & lower()
The instruction is lowered.
LegalizeRuleSet & custom()
Unconditionally custom lower.
LegalizeRuleSet & alwaysLegal()
LegalizeRuleSet & customIf(LegalityPredicate Predicate)
LegalizeRuleSet & scalarize(unsigned TypeIdx)
LegalizeRuleSet & legalForCartesianProduct(std::initializer_list< LLT > Types)
The instruction is legal when type indexes 0 and 1 are both in the given list.
LegalizeRuleSet & legalIf(LegalityPredicate Predicate)
The instruction is legal if predicate is true.
MachineIRBuilder & MIRBuilder
Expose MIRBuilder so clients can set their own RecordInsertInstruction functions.
LegalizeRuleSet & getActionDefinitionsBuilder(unsigned Opcode)
Get the action definition builder for the given opcode.
const LegacyLegalizerInfo & getLegacyLegalizerInfo() const
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
Definition: MachineInstr.h:69
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ=SPIRV::AccessQualifier::ReadWrite, bool EmitIR=true)
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, MachineFunction &MF)
SPIRVLegalizerInfo(const SPIRVSubtarget &ST)
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override
Called for instructions with the Custom LegalizationAction.
unsigned getPointerSize() const
bool canDirectlyComparePointers() const
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
LegalityPredicate typeInSet(unsigned TypeIdx, std::initializer_list< LLT > TypesInit)
True iff the given type index is one of the specified types.
Predicate all(Predicate P0, Predicate P1)
True iff P0 and P1 are true.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
std::function< bool(const LegalityQuery &)> LegalityPredicate
The LegalityQuery object bundles together all the information that's needed to decide whether a given...