LLVM 22.0.0git
SPIRVLegalizerInfo.cpp
Go to the documentation of this file.
1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the targeting of the Machinelegalizer class for SPIR-V.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
22
23using namespace llvm;
24using namespace llvm::LegalizeActions;
25using namespace llvm::LegalityPredicates;
26
27LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
28 return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
29 const LLT Ty = Query.Types[TypeIdx];
30 return IsExtendedInts && Ty.isValid() && Ty.isScalar();
31 };
32}
33
35 using namespace TargetOpcode;
36
37 this->ST = &ST;
38 GR = ST.getSPIRVGlobalRegistry();
39
40 const LLT s1 = LLT::scalar(1);
41 const LLT s8 = LLT::scalar(8);
42 const LLT s16 = LLT::scalar(16);
43 const LLT s32 = LLT::scalar(32);
44 const LLT s64 = LLT::scalar(64);
45
46 const LLT v16s64 = LLT::fixed_vector(16, 64);
47 const LLT v16s32 = LLT::fixed_vector(16, 32);
48 const LLT v16s16 = LLT::fixed_vector(16, 16);
49 const LLT v16s8 = LLT::fixed_vector(16, 8);
50 const LLT v16s1 = LLT::fixed_vector(16, 1);
51
52 const LLT v8s64 = LLT::fixed_vector(8, 64);
53 const LLT v8s32 = LLT::fixed_vector(8, 32);
54 const LLT v8s16 = LLT::fixed_vector(8, 16);
55 const LLT v8s8 = LLT::fixed_vector(8, 8);
56 const LLT v8s1 = LLT::fixed_vector(8, 1);
57
58 const LLT v4s64 = LLT::fixed_vector(4, 64);
59 const LLT v4s32 = LLT::fixed_vector(4, 32);
60 const LLT v4s16 = LLT::fixed_vector(4, 16);
61 const LLT v4s8 = LLT::fixed_vector(4, 8);
62 const LLT v4s1 = LLT::fixed_vector(4, 1);
63
64 const LLT v3s64 = LLT::fixed_vector(3, 64);
65 const LLT v3s32 = LLT::fixed_vector(3, 32);
66 const LLT v3s16 = LLT::fixed_vector(3, 16);
67 const LLT v3s8 = LLT::fixed_vector(3, 8);
68 const LLT v3s1 = LLT::fixed_vector(3, 1);
69
70 const LLT v2s64 = LLT::fixed_vector(2, 64);
71 const LLT v2s32 = LLT::fixed_vector(2, 32);
72 const LLT v2s16 = LLT::fixed_vector(2, 16);
73 const LLT v2s8 = LLT::fixed_vector(2, 8);
74 const LLT v2s1 = LLT::fixed_vector(2, 1);
75
76 const unsigned PSize = ST.getPointerSize();
77 const LLT p0 = LLT::pointer(0, PSize); // Function
78 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
79 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
80 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
81 const LLT p4 = LLT::pointer(4, PSize); // Generic
82 const LLT p5 =
83 LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
84 const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
85 const LLT p7 = LLT::pointer(7, PSize); // Input
86 const LLT p8 = LLT::pointer(8, PSize); // Output
87 const LLT p10 = LLT::pointer(10, PSize); // Private
88 const LLT p11 = LLT::pointer(11, PSize); // StorageBuffer
89 const LLT p12 = LLT::pointer(12, PSize); // Uniform
90
91 // TODO: remove copy-pasting here by using concatenation in some way.
92 auto allPtrsScalarsAndVectors = {
93 p0, p1, p2, p3, p4, p5, p6, p7, p8,
94 p10, p11, p12, s1, s8, s16, s32, s64, v2s1,
95 v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64,
96 v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, v8s32,
97 v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
98
99 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
100 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
101 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
102 v16s8, v16s16, v16s32, v16s64};
103
104 auto allScalarsAndVectors = {
105 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
106 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
107 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
108
109 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
110 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
111 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
112 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
113
114 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
115
116 auto allIntScalars = {s8, s16, s32, s64};
117
118 auto allFloatScalars = {s16, s32, s64};
119
120 auto allFloatScalarsAndVectors = {
121 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
122 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
123
124 auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3,
125 p4, p5, p6, p7, p8, p10, p11, p12};
126
127 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
128
129 bool IsExtendedInts =
130 ST.canUseExtension(
131 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
132 ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
133 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
134 auto extendedScalarsAndVectors =
135 [IsExtendedInts](const LegalityQuery &Query) {
136 const LLT Ty = Query.Types[0];
137 return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
138 };
139 auto extendedScalarsAndVectorsProduct = [IsExtendedInts](
140 const LegalityQuery &Query) {
141 const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1];
142 return IsExtendedInts && Ty1.isValid() && Ty2.isValid() &&
143 !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector();
144 };
145 auto extendedPtrsScalarsAndVectors =
146 [IsExtendedInts](const LegalityQuery &Query) {
147 const LLT Ty = Query.Types[0];
148 return IsExtendedInts && Ty.isValid();
149 };
150
153
155
156 // TODO: add proper rules for vectors legalization.
158 {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
159 .alwaysLegal();
160
161 // Vector Reduction Operations
163 {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
164 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
165 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
166 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
167 .legalFor(allVectors)
168 .scalarize(1)
169 .lower();
170
171 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
172 .scalarize(2)
173 .lower();
174
175 // Merge/Unmerge
176 // TODO: add proper legalization rules.
177 getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
178
179 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
180 .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
181
183 all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars)));
184
185 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
186 .legalForCartesianProduct(allPtrs, allPtrs);
187
188 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
189
190 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
191 G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
192 G_USUBSAT, G_SCMP, G_UCMP})
193 .legalFor(allIntScalarsAndVectors)
194 .legalIf(extendedScalarsAndVectors);
195
196 getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
197 .legalFor(allFloatScalarsAndVectors);
198
199 getActionDefinitionsBuilder(G_STRICT_FLDEXP)
200 .legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars);
201
202 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
203 .legalForCartesianProduct(allIntScalarsAndVectors,
204 allFloatScalarsAndVectors);
205
206 getActionDefinitionsBuilder({G_FPTOSI_SAT, G_FPTOUI_SAT})
207 .legalForCartesianProduct(allIntScalarsAndVectors,
208 allFloatScalarsAndVectors);
209
210 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
211 .legalForCartesianProduct(allFloatScalarsAndVectors,
212 allScalarsAndVectors);
213
215 .legalForCartesianProduct(allIntScalarsAndVectors)
216 .legalIf(extendedScalarsAndVectorsProduct);
217
218 // Extensions.
219 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
220 .legalForCartesianProduct(allScalarsAndVectors)
221 .legalIf(extendedScalarsAndVectorsProduct);
222
224 .legalFor(allPtrsScalarsAndVectors)
225 .legalIf(extendedPtrsScalarsAndVectors);
226
228 all(typeInSet(0, allPtrsScalarsAndVectors),
229 typeInSet(1, allPtrsScalarsAndVectors)));
230
231 getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
232
233 getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
234
236 .legalForCartesianProduct(allPtrs, allIntScalars)
237 .legalIf(
238 all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
240 .legalForCartesianProduct(allIntScalars, allPtrs)
241 .legalIf(
242 all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)));
244 .legalForCartesianProduct(allPtrs, allIntScalars)
245 .legalIf(
246 all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
247
248 // ST.canDirectlyComparePointers() for pointer args is supported in
249 // legalizeCustom().
251 all(typeInSet(0, allBoolScalarsAndVectors),
252 typeInSet(1, allPtrsScalarsAndVectors)));
253
255 all(typeInSet(0, allBoolScalarsAndVectors),
256 typeInSet(1, allFloatScalarsAndVectors)));
257
258 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
259 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
260 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
261 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
262 .legalForCartesianProduct(allIntScalars, allPtrs);
263
265 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
266 .legalForCartesianProduct(allFloatScalars, allPtrs);
267
268 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
269 .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);
270
271 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
272 // TODO: add proper legalization rules.
273 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
274
276 {G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
277 .alwaysLegal();
278
279 getActionDefinitionsBuilder({G_LROUND, G_LLROUND})
280 .legalForCartesianProduct(allFloatScalarsAndVectors,
281 allIntScalarsAndVectors);
282
283 // FP conversions.
284 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
285 .legalForCartesianProduct(allFloatScalarsAndVectors);
286
287 // Pointer-handling.
288 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
289
290 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
291 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
292
294 allFloatScalarsAndVectors, {s32, v2s32, v3s32, v4s32, v8s32, v16s32});
295
296 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
297 // tighten these requirements. Many of these math functions are only legal on
298 // specific bitwidths, so they are not selectable for
299 // allFloatScalarsAndVectors.
300 getActionDefinitionsBuilder({G_STRICT_FSQRT,
301 G_FPOW,
302 G_FEXP,
303 G_FMODF,
304 G_FEXP2,
305 G_FLOG,
306 G_FLOG2,
307 G_FLOG10,
308 G_FABS,
309 G_FMINNUM,
310 G_FMAXNUM,
311 G_FCEIL,
312 G_FCOS,
313 G_FSIN,
314 G_FTAN,
315 G_FACOS,
316 G_FASIN,
317 G_FATAN,
318 G_FATAN2,
319 G_FCOSH,
320 G_FSINH,
321 G_FTANH,
322 G_FSQRT,
323 G_FFLOOR,
324 G_FRINT,
325 G_FNEARBYINT,
326 G_INTRINSIC_ROUND,
327 G_INTRINSIC_TRUNC,
328 G_FMINIMUM,
329 G_FMAXIMUM,
330 G_INTRINSIC_ROUNDEVEN})
331 .legalFor(allFloatScalarsAndVectors);
332
333 getActionDefinitionsBuilder(G_FCOPYSIGN)
334 .legalForCartesianProduct(allFloatScalarsAndVectors,
335 allFloatScalarsAndVectors);
336
338 allFloatScalarsAndVectors, allIntScalarsAndVectors);
339
340 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
342 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
343 .legalForCartesianProduct(allIntScalarsAndVectors,
344 allIntScalarsAndVectors);
345
346 // Struct return types become a single scalar, so cannot easily legalize.
347 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
348 }
349
350 getActionDefinitionsBuilder(G_IS_FPCLASS).custom();
351
353 verify(*ST.getInstrInfo());
354}
355
357 LegalizerHelper &Helper,
360 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
361 MRI.setRegClass(ConvReg, GR->getRegClass(SpvType));
362 GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF());
363 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
364 .addDef(ConvReg)
365 .addUse(Reg);
366 return ConvReg;
367}
368
371 LostDebugLocObserver &LocObserver) const {
372 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
373 switch (MI.getOpcode()) {
374 default:
375 // TODO: implement legalization for other opcodes.
376 return true;
377 case TargetOpcode::G_IS_FPCLASS:
378 return legalizeIsFPClass(Helper, MI, LocObserver);
379 case TargetOpcode::G_ICMP: {
380 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
381 auto &Op0 = MI.getOperand(2);
382 auto &Op1 = MI.getOperand(3);
383 Register Reg0 = Op0.getReg();
384 Register Reg1 = Op1.getReg();
386 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
387 if ((!ST->canDirectlyComparePointers() ||
389 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
390 LLT ConvT = LLT::scalar(ST->getPointerSize());
391 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
392 ST->getPointerSize());
393 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(
394 LLVMTy, Helper.MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
395 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
396 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
397 }
398 return true;
399 }
400 }
401}
402
403// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
404// to ensure that all instructions created during the lowering have SPIR-V types
405// assigned to them.
406bool SPIRVLegalizerInfo::legalizeIsFPClass(
408 LostDebugLocObserver &LocObserver) const {
409 auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
410 FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());
411
412 auto &MIRBuilder = Helper.MIRBuilder;
413 auto &MF = MIRBuilder.getMF();
414 MachineRegisterInfo &MRI = MF.getRegInfo();
415
416 Type *LLVMDstTy =
417 IntegerType::get(MIRBuilder.getContext(), DstTy.getScalarSizeInBits());
418 if (DstTy.isVector())
419 LLVMDstTy = VectorType::get(LLVMDstTy, DstTy.getElementCount());
420 SPIRVType *SPIRVDstTy = GR->getOrCreateSPIRVType(
421 LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
422 /*EmitIR*/ true);
423
424 unsigned BitSize = SrcTy.getScalarSizeInBits();
425 const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());
426
427 LLT IntTy = LLT::scalar(BitSize);
428 Type *LLVMIntTy = IntegerType::get(MIRBuilder.getContext(), BitSize);
429 if (SrcTy.isVector()) {
430 IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
431 LLVMIntTy = VectorType::get(LLVMIntTy, SrcTy.getElementCount());
432 }
433 SPIRVType *SPIRVIntTy = GR->getOrCreateSPIRVType(
434 LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
435 /*EmitIR*/ true);
436
437 // Clang doesn't support capture of structured bindings:
438 LLT DstTyCopy = DstTy;
439 const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
440 // Assign this MI's (assumed only) destination to one of the two types we
441 // expect: either the G_IS_FPCLASS's destination type, or the integer type
442 // bitcast from the source type.
443 LLT MITy = MRI.getType(MI.getReg(0));
444 assert((MITy == IntTy || MITy == DstTyCopy) &&
445 "Unexpected LLT type while lowering G_IS_FPCLASS");
446 auto *SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
447 GR->assignSPIRVTypeToVReg(SPVTy, MI.getReg(0), MF);
448 return MI;
449 };
450
451 // Helper to build and assign a constant in one go
452 const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder {
453 if (!Ty.isFixedVector())
454 return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C));
455 auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C);
456 assert((Ty == IntTy || Ty == DstTyCopy) &&
457 "Unexpected LLT type while lowering constant for G_IS_FPCLASS");
458 SPIRVType *VecEltTy = GR->getOrCreateSPIRVType(
459 (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder,
460 SPIRV::AccessQualifier::ReadWrite,
461 /*EmitIR*/ true);
462 GR->assignSPIRVTypeToVReg(VecEltTy, ScalarC.getReg(0), MF);
463 return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Ty, ScalarC));
464 };
465
466 if (Mask == fcNone) {
467 MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 0));
468 MI.eraseFromParent();
469 return true;
470 }
471 if (Mask == fcAllFlags) {
472 MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 1));
473 MI.eraseFromParent();
474 return true;
475 }
476
477 // Note that rather than creating a COPY here (between a floating-point and
478 // integer type of the same size) we create a SPIR-V bitcast immediately. We
479 // can't create a G_BITCAST because the LLTs are the same, and we can't seem
480 // to correctly lower COPYs to SPIR-V bitcasts at this moment.
481 Register ResVReg = MRI.createGenericVirtualRegister(IntTy);
482 MRI.setRegClass(ResVReg, GR->getRegClass(SPIRVIntTy));
483 GR->assignSPIRVTypeToVReg(SPIRVIntTy, ResVReg, Helper.MIRBuilder.getMF());
484 auto AsInt = MIRBuilder.buildInstr(SPIRV::OpBitcast)
485 .addDef(ResVReg)
486 .addUse(GR->getSPIRVTypeID(SPIRVIntTy))
487 .addUse(SrcReg);
488 AsInt = assignSPIRVTy(std::move(AsInt));
489
490 // Various masks.
491 APInt SignBit = APInt::getSignMask(BitSize);
492 APInt ValueMask = APInt::getSignedMaxValue(BitSize); // All bits but sign.
493 APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
494 APInt ExpMask = Inf;
495 APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
496 APInt QNaNBitMask =
497 APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
498 APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());
499
500 auto SignBitC = buildSPIRVConstant(IntTy, SignBit);
501 auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask);
502 auto InfC = buildSPIRVConstant(IntTy, Inf);
503 auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask);
504 auto ZeroC = buildSPIRVConstant(IntTy, 0);
505
506 auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC));
507 auto Sign = assignSPIRVTy(
508 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs));
509
510 auto Res = buildSPIRVConstant(DstTy, 0);
511
512 const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
513 Res = assignSPIRVTy(
514 MIRBuilder.buildOr(DstTyCopy, Res, assignSPIRVTy(std::move(ToAppend))));
515 };
516
517 // Tests that involve more than one class should be processed first.
518 if ((Mask & fcFinite) == fcFinite) {
519 // finite(V) ==> abs(V) u< exp_mask
520 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
521 ExpMaskC));
522 Mask &= ~fcFinite;
523 } else if ((Mask & fcFinite) == fcPosFinite) {
524 // finite(V) && V > 0 ==> V u< exp_mask
525 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
526 ExpMaskC));
527 Mask &= ~fcPosFinite;
528 } else if ((Mask & fcFinite) == fcNegFinite) {
529 // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
530 auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT,
531 DstTy, Abs, ExpMaskC));
532 appendToRes(MIRBuilder.buildAnd(DstTy, Cmp, Sign));
533 Mask &= ~fcNegFinite;
534 }
535
536 if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
537 // fcZero | fcSubnormal => test all exponent bits are 0
538 // TODO: Handle sign bit specific cases
539 // TODO: Handle inverted case
540 if (PartialCheck == (fcZero | fcSubnormal)) {
541 auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC));
542 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
543 ExpBits, ZeroC));
544 Mask &= ~PartialCheck;
545 }
546 }
547
548 // Check for individual classes.
549 if (FPClassTest PartialCheck = Mask & fcZero) {
550 if (PartialCheck == fcPosZero)
551 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
552 AsInt, ZeroC));
553 else if (PartialCheck == fcZero)
554 appendToRes(
555 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
556 else // fcNegZero
557 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
558 AsInt, SignBitC));
559 }
560
561 if (FPClassTest PartialCheck = Mask & fcSubnormal) {
562 // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
563 // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
564 auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
565 auto OneC = buildSPIRVConstant(IntTy, 1);
566 auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
567 auto SubnormalRes = assignSPIRVTy(
568 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
569 buildSPIRVConstant(IntTy, AllOneMantissa)));
570 if (PartialCheck == fcNegSubnormal)
571 SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
572 appendToRes(std::move(SubnormalRes));
573 }
574
575 if (FPClassTest PartialCheck = Mask & fcInf) {
576 if (PartialCheck == fcPosInf)
577 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
578 AsInt, InfC));
579 else if (PartialCheck == fcInf)
580 appendToRes(
581 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
582 else { // fcNegInf
583 APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
584 auto NegInfC = buildSPIRVConstant(IntTy, NegInf);
585 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
586 AsInt, NegInfC));
587 }
588 }
589
590 if (FPClassTest PartialCheck = Mask & fcNan) {
591 auto InfWithQnanBitC =
592 buildSPIRVConstant(IntTy, std::move(Inf) | QNaNBitMask);
593 if (PartialCheck == fcNan) {
594 // isnan(V) ==> abs(V) u> int(inf)
595 appendToRes(
596 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
597 } else if (PartialCheck == fcQNan) {
598 // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
599 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
600 InfWithQnanBitC));
601 } else { // fcSNan
602 // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
603 // abs(V) u< (unsigned(Inf) | quiet_bit)
604 auto IsNan = assignSPIRVTy(
605 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
606 auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp(
607 CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC));
608 appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
609 }
610 }
611
612 if (FPClassTest PartialCheck = Mask & fcNormal) {
613 // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
614 // (max_exp-1))
615 APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
616 auto ExpMinusOne = assignSPIRVTy(
617 MIRBuilder.buildSub(IntTy, Abs, buildSPIRVConstant(IntTy, ExpLSB)));
618 APInt MaxExpMinusOne = std::move(ExpMask) - ExpLSB;
619 auto NormalRes = assignSPIRVTy(
620 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
621 buildSPIRVConstant(IntTy, MaxExpMinusOne)));
622 if (PartialCheck == fcNegNormal)
623 NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
624 else if (PartialCheck == fcPosNormal) {
625 auto PosSign = assignSPIRVTy(MIRBuilder.buildXor(
626 DstTy, Sign, buildSPIRVConstant(DstTy, InversionMask)));
627 NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
628 }
629 appendToRes(std::move(NormalRes));
630 }
631
632 MIRBuilder.buildCopy(DstReg, Res);
633 MI.eraseFromParent();
634 return true;
635}
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static void scalarize(Instruction *I, SmallVectorImpl< Instruction * > &Replace)
Definition ExpandFp.cpp:942
IRTranslator LLVM IR MI
This file declares the MachineIRBuilder class.
Register Reg
Promote Memory to Register
Definition Mem2Reg.cpp:110
ppc ctr loops verify
const SmallVectorImpl< MachineOperand > & Cond
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts)
APInt bitcastToAPInt() const
Definition APFloat.h:1353
static APFloat getLargest(const fltSemantics &Sem, bool Negative=false)
Returns the largest finite number in the given semantics.
Definition APFloat.h:1138
static APFloat getInf(const fltSemantics &Sem, bool Negative=false)
Factory for Positive and Negative Infinity.
Definition APFloat.h:1098
static APInt getAllOnes(unsigned numBits)
Return an APInt of a specified width with all bits set.
Definition APInt.h:234
static APInt getSignMask(unsigned BitWidth)
Get the SignMask for a specific bit width.
Definition APInt.h:229
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1512
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:209
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:873
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:239
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:678
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:702
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:701
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:703
@ ICMP_NE
not equal
Definition InstrTypes.h:700
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:319
static constexpr LLT vector(ElementCount EC, unsigned ScalarSizeInBits)
Get a low-level vector of some number of elements and element width.
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
constexpr bool isValid() const
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
constexpr bool isPointerOrPointerVector() const
constexpr bool isFixedVector() const
Returns true if the LLT is a fixed vector.
constexpr LLT getScalarType() const
LLVM_ABI void computeTables()
Compute any ancillary tables needed to quickly decide how an operation should be handled.
LegalizeRuleSet & legalFor(std::initializer_list< LLT > Types)
The instruction is legal when type index 0 is any type in the given list.
LegalizeRuleSet & lower()
The instruction is lowered.
LegalizeRuleSet & custom()
Unconditionally custom lower.
LegalizeRuleSet & alwaysLegal()
LegalizeRuleSet & customIf(LegalityPredicate Predicate)
LegalizeRuleSet & scalarize(unsigned TypeIdx)
LegalizeRuleSet & legalForCartesianProduct(std::initializer_list< LLT > Types)
The instruction is legal when type indexes 0 and 1 are both in the given list.
LegalizeRuleSet & legalIf(LegalityPredicate Predicate)
The instruction is legal if predicate is true.
MachineIRBuilder & MIRBuilder
Expose MIRBuilder so clients can set their own RecordInsertInstruction functions.
LegalizeRuleSet & getActionDefinitionsBuilder(unsigned Opcode)
Get the action definition builder for the given opcode.
const LegacyLegalizerInfo & getLegacyLegalizerInfo() const
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Wrapper class representing virtual and physical registers.
Definition Register.h:19
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, const MachineFunction &MF)
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineInstr &I, SPIRV::AccessQualifier::AccessQualifier AQ, bool EmitIR)
const TargetRegisterClass * getRegClass(SPIRVType *SpvType) const
SPIRVLegalizerInfo(const SPIRVSubtarget &ST)
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override
Called for instructions with the Custom LegalizationAction.
SPIRVGlobalRegistry * getSPIRVGlobalRegistry() const
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI VectorType * get(Type *ElementType, ElementCount EC)
This static method is the primary way to construct an VectorType.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
LLVM_ABI LegalityPredicate typeInSet(unsigned TypeIdx, std::initializer_list< LLT > TypesInit)
True iff the given type index is one of the specified types.
Predicate all(Predicate P0, Predicate P1)
True iff P0 and P1 are true.
Invariant opcodes: All instruction sets have these as their low opcodes.
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI const llvm::fltSemantics & getFltSemanticForLLT(LLT Ty)
Get the appropriate floating point arithmetic semantic based on the bit size of the given scalar LLT.
std::function< bool(const LegalityQuery &)> LegalityPredicate
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
FPClassTest
Floating-point class tests, supported by 'is_fpclass' intrinsic.
const MachineInstr SPIRVType
const std::set< unsigned > & getTypeFoldingSupportedOpcodes()
The LegalityQuery object bundles together all the information that's needed to decide whether a given...