//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the targeting of the MachineLegalizer class for SPIR-V.
//
//===----------------------------------------------------------------------===//

#include "SPIRVLegalizerInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"

using namespace llvm;
using namespace llvm::LegalizeActions;
using namespace llvm::LegalityPredicates;

#define DEBUG_TYPE "spirv-legalizer"

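// Returns a predicate that holds only when the arbitrary-width integer
// extensions are enabled and the type at TypeIdx is a plain scalar (not a
// pointer and not a vector).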
LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
  return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    return IsExtendedInts && Ty.isValid() && Ty.isScalar();
  };
}

SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
  using namespace TargetOpcode;

  this->ST = &ST;
  GR = ST.getSPIRVGlobalRegistry();

  const LLT s1 = LLT::scalar(1);
  const LLT s8 = LLT::scalar(8);
  const LLT s16 = LLT::scalar(16);
  const LLT s32 = LLT::scalar(32);
  const LLT s64 = LLT::scalar(64);
  const LLT s128 = LLT::scalar(128);

  const LLT v16s64 = LLT::fixed_vector(16, 64);
  const LLT v16s32 = LLT::fixed_vector(16, 32);
  const LLT v16s16 = LLT::fixed_vector(16, 16);
  const LLT v16s8 = LLT::fixed_vector(16, 8);
  const LLT v16s1 = LLT::fixed_vector(16, 1);

  const LLT v8s64 = LLT::fixed_vector(8, 64);
  const LLT v8s32 = LLT::fixed_vector(8, 32);
  const LLT v8s16 = LLT::fixed_vector(8, 16);
  const LLT v8s8 = LLT::fixed_vector(8, 8);
  const LLT v8s1 = LLT::fixed_vector(8, 1);

  const LLT v4s64 = LLT::fixed_vector(4, 64);
  const LLT v4s32 = LLT::fixed_vector(4, 32);
  const LLT v4s16 = LLT::fixed_vector(4, 16);
  const LLT v4s8 = LLT::fixed_vector(4, 8);
  const LLT v4s1 = LLT::fixed_vector(4, 1);

  const LLT v3s64 = LLT::fixed_vector(3, 64);
  const LLT v3s32 = LLT::fixed_vector(3, 32);
  const LLT v3s16 = LLT::fixed_vector(3, 16);
  const LLT v3s8 = LLT::fixed_vector(3, 8);
  const LLT v3s1 = LLT::fixed_vector(3, 1);

  const LLT v2s64 = LLT::fixed_vector(2, 64);
  const LLT v2s32 = LLT::fixed_vector(2, 32);
  const LLT v2s16 = LLT::fixed_vector(2, 16);
  const LLT v2s8 = LLT::fixed_vector(2, 8);
  const LLT v2s1 = LLT::fixed_vector(2, 1);

  const unsigned PSize = ST.getPointerSize();
  const LLT p0 = LLT::pointer(0, PSize); // Function
  const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
  const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
  const LLT p3 = LLT::pointer(3, PSize); // Workgroup
  const LLT p4 = LLT::pointer(4, PSize); // Generic
  const LLT p5 =
      LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
  const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
  const LLT p7 = LLT::pointer(7, PSize);  // Input
  const LLT p8 = LLT::pointer(8, PSize);  // Output
  const LLT p9 =
      LLT::pointer(9, PSize); // CodeSectionINTEL, SPV_INTEL_function_pointers
  const LLT p10 = LLT::pointer(10, PSize); // Private
  const LLT p11 = LLT::pointer(11, PSize); // StorageBuffer
  const LLT p12 = LLT::pointer(12, PSize); // Uniform
  const LLT p13 = LLT::pointer(13, PSize); // PushConstant

  // TODO: remove copy-pasting here by using concatenation in some way.
  auto allPtrsScalarsAndVectors = {
      p0,    p1,    p2,    p3,    p4,    p5,     p6,     p7,    p8,
      p9,    p10,   p11,   p12,   p13,   s1,     s8,     s16,   s32,
      s64,   v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,  v3s16,
      v3s32, v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,  v8s8,
      v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};

  auto allVectors = {v2s1,  v2s8,   v2s16,  v2s32, v2s64, v3s1,  v3s8,
                     v3s16, v3s32,  v3s64,  v4s1,  v4s8,  v4s16, v4s32,
                     v4s64, v8s1,   v8s8,   v8s16, v8s32, v8s64, v16s1,
                     v16s8, v16s16, v16s32, v16s64};

  auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
                           v3s1, v3s8, v3s16, v3s32, v3s64,
                           v4s1, v4s8, v4s16, v4s32, v4s64};

  auto allScalars = {s1, s8, s16, s32, s64};

  auto allScalarsAndVectors = {
      s1,    s8,    s16,   s32,   s64,   s128,  v2s1,   v2s8,
      v2s16, v2s32, v2s64, v3s1,  v3s8,  v3s16, v3s32,  v3s64,
      v4s1,  v4s8,  v4s16, v4s32, v4s64, v8s1,  v8s8,   v8s16,
      v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};

  auto allIntScalarsAndVectors = {
      s8,    s16,   s32,   s64,   s128,  v2s8,   v2s16,  v2s32, v2s64,
      v3s8,  v3s16, v3s32, v3s64, v4s8,  v4s16,  v4s32,  v4s64, v8s8,
      v8s16, v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};

  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};

  auto allIntScalars = {s8, s16, s32, s64, s128};

  auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16};

  auto allFloatScalarsAndVectors = {
      s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
      v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};

  auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0,  p1,
                                       p2, p3,  p4,  p5,  p6,  p7,
                                       p8, p9,  p10, p11, p12, p13};

  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13};

  auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;

  bool IsExtendedInts =
      ST.canUseExtension(
          SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers) ||
      ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
      ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
  auto extendedScalarsAndVectors =
      [IsExtendedInts](const LegalityQuery &Query) {
        const LLT Ty = Query.Types[0];
        return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
      };
  auto extendedScalarsAndVectorsProduct = [IsExtendedInts](
                                              const LegalityQuery &Query) {
    const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1];
    return IsExtendedInts && Ty1.isValid() && Ty2.isValid() &&
           !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector();
  };
  auto extendedPtrsScalarsAndVectors =
      [IsExtendedInts](const LegalityQuery &Query) {
        const LLT Ty = Query.Types[0];
        return IsExtendedInts && Ty.isValid();
      };

  // The universal validation rules in the SPIR-V specification state that
  // vector sizes are typically limited to 2, 3, or 4. However, larger vector
  // sizes (8 and 16) are enabled when the Kernel capability is present. For
  // shader execution models, vector sizes are strictly limited to 4. In
  // non-shader contexts, vector sizes of 8 and 16 are also permitted, but
  // arbitrary sizes (e.g., 6 or 11) are not.
  uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
  LLVM_DEBUG(dbgs() << "MaxVectorSize: " << MaxVectorSize << "\n");

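  // Opcodes that support type folding get custom rules so that the SPIR-V
  // types of their results are preserved. Vector operands with unsupported
  // element counts are normalized toward MaxVectorSize first (e.g. a
  // 6-element vector is padded to 8 elements and, on shader targets, split
  // into 4-element pieces). The division/remainder opcodes and
  // G_EXTRACT_VECTOR_ELT are skipped here because they get dedicated rule
  // sets right below.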
  for (auto Opc : getTypeFoldingSupportedOpcodes()) {
    switch (Opc) {
    case G_EXTRACT_VECTOR_ELT:
    case G_UREM:
    case G_SREM:
    case G_UDIV:
    case G_SDIV:
    case G_FREM:
      break;
    default:
      getActionDefinitionsBuilder(Opc)
          .customFor(allScalars)
          .customFor(allowedVectorTypes)
          .moreElementsToNextPow2(0)
          .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                           changeElementCountTo(
                               0, ElementCount::getFixed(MaxVectorSize)))
          .custom();
      break;
    }
  }

  getActionDefinitionsBuilder({G_UREM, G_SREM, G_SDIV, G_UDIV, G_FREM})
      .customFor(allScalars)
      .customFor(allowedVectorTypes)
      .moreElementsToNextPow2(0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)))
      .custom();

  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
      .legalFor(allScalars)
      .legalFor(allowedVectorTypes)
      .moreElementsToNextPow2(0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)))
      .alwaysLegal();

  getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();

  getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
      .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
      .moreElementsToNextPow2(0)
      .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
      .moreElementsToNextPow2(1)
      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize));

  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
      .moreElementsToNextPow2(1)
      .fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
                       changeElementCountTo(
                           1, ElementCount::getFixed(MaxVectorSize)))
      .custom();

  getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
      .moreElementsToNextPow2(0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)))
      .custom();

  // Illegal G_UNMERGE_VALUES instructions should be handled
  // during the combine phase.
  getActionDefinitionsBuilder(G_BUILD_VECTOR).alwaysLegal();

  // When entering the legalizer, there should be no G_BITCAST instructions.
  // They should all be calls to the `spv_bitcast` intrinsic. The call to
  // the intrinsic will be converted to a G_BITCAST during legalization if
  // the vectors are not legal. After using the rules to legalize a G_BITCAST,
  // we turn it back into a call to the intrinsic with a custom rule to avoid
  // potential machine verifier failures.
  getActionDefinitionsBuilder(G_BITCAST)
      .moreElementsToNextPow2(0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)))
      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
      .custom();

  // If the result is still illegal, the combiner should be able to remove it.
  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
      .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes);

  getActionDefinitionsBuilder(G_SPLAT_VECTOR)
      .legalFor(allowedVectorTypes)
      .moreElementsToNextPow2(0)
      .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
      .alwaysLegal();

  // Vector Reduction Operations
  getActionDefinitionsBuilder(
      {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
       G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
       G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
       G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
      .legalFor(allowedVectorTypes)
      .scalarize(1)
      .lower();

  getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
      .scalarize(2)
      .lower();

  // Illegal G_UNMERGE_VALUES instructions should be handled
  // during the combine phase.
  getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();

  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
      .unsupportedIf(LegalityPredicates::any(typeIs(0, p9), typeIs(1, p9)))
      .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));

  getActionDefinitionsBuilder(G_MEMSET)
      .unsupportedIf(typeIs(0, p9))
      .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars)));

  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .unsupportedIf(LegalityPredicates::any(
          all(typeIs(0, p9), typeIsNot(1, p9)),
          all(typeIsNot(0, p9), typeIs(1, p9))))
      .legalForCartesianProduct(allPtrs, allPtrs);

  // Should we be legalizing bad scalar sizes like s5 here instead
  // of handling them in the instruction selector?
  getActionDefinitionsBuilder({G_LOAD, G_STORE})
      .unsupportedIf(typeIs(1, p9))
      .legalForCartesianProduct(allowedVectorTypes, allPtrs)
      .legalForCartesianProduct(allPtrs, allPtrs)
      .legalIf(isScalar(0))
      .custom();

  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
                               G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
                               G_USUBSAT, G_SCMP, G_UCMP})
      .legalFor(allIntScalarsAndVectors)
      .legalIf(extendedScalarsAndVectors);

  getActionDefinitionsBuilder(G_STRICT_FLDEXP)
      .legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
      .legalForCartesianProduct(allIntScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder({G_FPTOSI_SAT, G_FPTOUI_SAT})
      .legalForCartesianProduct(allIntScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allScalarsAndVectors);

  getActionDefinitionsBuilder(G_CTPOP)
      .legalForCartesianProduct(allIntScalarsAndVectors)
      .legalIf(extendedScalarsAndVectorsProduct);

  // Extensions.
  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
      .legalForCartesianProduct(allScalarsAndVectors)
      .legalIf(extendedScalarsAndVectorsProduct);

  getActionDefinitionsBuilder(G_PHI)
      .legalFor(allPtrsScalarsAndVectors)
      .legalIf(extendedPtrsScalarsAndVectors);

  getActionDefinitionsBuilder(G_SELECT).legalIf(
      all(typeInSet(0, allPtrsScalarsAndVectors),
          typeInSet(1, allPtrsScalarsAndVectors)));

  getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
      .legalFor({s1, s128})
      .legalFor(allFloatAndIntScalarsAndPtrs)
      .legalFor(allowedVectorTypes)
      .moreElementsToNextPow2(0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                       changeElementCountTo(
                           0, ElementCount::getFixed(MaxVectorSize)));

  getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();

  getActionDefinitionsBuilder(G_INTTOPTR)
      .legalForCartesianProduct(allPtrs, allIntScalars)
      .legalIf(
          all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
  getActionDefinitionsBuilder(G_PTRTOINT)
      .legalForCartesianProduct(allIntScalars, allPtrs)
      .legalIf(
          all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)));
  getActionDefinitionsBuilder(G_PTR_ADD)
      .legalForCartesianProduct(allPtrs, allIntScalars)
      .legalIf(
          all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));

  // ST.canDirectlyComparePointers() for pointer args is supported in
  // legalizeCustom().
  getActionDefinitionsBuilder(G_ICMP)
      .unsupportedIf(LegalityPredicates::any(
          all(typeIs(0, p9), typeInSet(1, allPtrs), typeIsNot(1, p9)),
          all(typeInSet(0, allPtrs), typeIsNot(0, p9), typeIs(1, p9))))
      .customIf(all(typeInSet(0, allBoolScalarsAndVectors),
                    typeInSet(1, allPtrsScalarsAndVectors)));

  getActionDefinitionsBuilder(G_FCMP).legalIf(
      all(typeInSet(0, allBoolScalarsAndVectors),
          typeInSet(1, allFloatScalarsAndVectors)));

  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
                               G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
                               G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
                               G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
      .legalForCartesianProduct(allIntScalars, allPtrs);

  getActionDefinitionsBuilder(
      {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
      .legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s,
                                allPtrs);

  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
      .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);

  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
  // TODO: add proper legalization rules.
  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();

  getActionDefinitionsBuilder(
      {G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
      .alwaysLegal();

  getActionDefinitionsBuilder({G_LROUND, G_LLROUND})
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allIntScalarsAndVectors);

  // FP conversions.
  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
      .legalForCartesianProduct(allFloatScalarsAndVectors);

  // Pointer-handling.
  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});

  getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);

  // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});

  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
      allFloatScalarsAndVectors, {s32, v2s32, v3s32, v4s32, v8s32, v16s32});

  // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
  // tighten these requirements. Many of these math functions are only legal on
  // specific bitwidths, so they are not selectable for
  // allFloatScalarsAndVectors.
  getActionDefinitionsBuilder({G_STRICT_FSQRT,
                               G_FPOW,
                               G_FEXP,
                               G_FMODF,
                               G_FEXP2,
                               G_FLOG,
                               G_FLOG2,
                               G_FLOG10,
                               G_FABS,
                               G_FMINNUM,
                               G_FMAXNUM,
                               G_FCEIL,
                               G_FCOS,
                               G_FSIN,
                               G_FTAN,
                               G_FACOS,
                               G_FASIN,
                               G_FATAN,
                               G_FATAN2,
                               G_FCOSH,
                               G_FSINH,
                               G_FTANH,
                               G_FSQRT,
                               G_FFLOOR,
                               G_FRINT,
                               G_FNEARBYINT,
                               G_INTRINSIC_ROUND,
                               G_INTRINSIC_TRUNC,
                               G_FMINIMUM,
                               G_FMAXIMUM,
                               G_INTRINSIC_ROUNDEVEN})
      .legalFor(allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_FCOPYSIGN)
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_FLDEXP).legalForCartesianProduct(
      allFloatScalarsAndVectors, allIntScalarsAndVectors);

  if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
    getActionDefinitionsBuilder(
        {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
        .legalForCartesianProduct(allIntScalarsAndVectors,
                                  allIntScalarsAndVectors);

    // Struct return types become a single scalar, so cannot easily legalize.
    getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
  }

  getActionDefinitionsBuilder(G_IS_FPCLASS).custom();

  getLegacyLegalizerInfo().computeTables();
  verify(*ST.getInstrInfo());
}

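// Rewrites G_EXTRACT_VECTOR_ELT as a call to the spv_extractelt intrinsic so
// the SPIR-V form of the operation survives until instruction selection.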
static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
                                     SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  Register IdxReg = MI.getOperand(2).getReg();

  MIRBuilder
      .buildIntrinsic(Intrinsic::spv_extractelt, ArrayRef<Register>{DstReg})
      .addUse(SrcReg)
      .addUse(IdxReg);
  MI.eraseFromParent();
  return true;
}

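// Rewrites G_INSERT_VECTOR_ELT as a call to the spv_insertelt intrinsic,
// mirroring legalizeExtractVectorElt above.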
static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
                                    SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  Register ValReg = MI.getOperand(2).getReg();
  Register IdxReg = MI.getOperand(3).getReg();

  MIRBuilder
      .buildIntrinsic(Intrinsic::spv_insertelt, ArrayRef<Register>{DstReg})
      .addUse(SrcReg)
      .addUse(ValReg)
      .addUse(IdxReg);
  MI.eraseFromParent();
  return true;
}

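// Emits a G_PTRTOINT from Reg into a fresh virtual register of type ConvTy and
// tags the result with the given SPIR-V type; used below to turn pointer
// comparisons into integer comparisons.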
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
                                LegalizerHelper &Helper,
                                MachineRegisterInfo &MRI,
                                SPIRVGlobalRegistry *GR) {
  Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
  MRI.setRegClass(ConvReg, GR->getRegClass(SpvType));
  GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF());
  Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
      .addDef(ConvReg)
      .addUse(Reg);
  return ConvReg;
}

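// Scalarizes a vector G_STORE: the value is split with G_UNMERGE_VALUES and
// each element is stored through an element pointer computed by spv_gep, with
// the memory operand's offset and alignment adjusted per element.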
static bool legalizeStore(LegalizerHelper &Helper, MachineInstr &MI,
                          SPIRVGlobalRegistry *GR) {
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register ValReg = MI.getOperand(0).getReg();
  Register PtrReg = MI.getOperand(1).getReg();
  LLT ValTy = MRI.getType(ValReg);

  assert(ValTy.isVector() && "Expected vector store");

  SmallVector<Register, 8> SplitRegs;
  LLT EltTy = ValTy.getElementType();
  unsigned NumElts = ValTy.getNumElements();

  for (unsigned i = 0; i < NumElts; ++i)
    SplitRegs.push_back(MRI.createGenericVirtualRegister(EltTy));

  MIRBuilder.buildUnmerge(SplitRegs, ValReg);

  LLT PtrTy = MRI.getType(PtrReg);
  auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);

  for (unsigned i = 0; i < NumElts; ++i) {
    auto Idx = MIRBuilder.buildConstant(LLT::scalar(32), i);
    Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);

    MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
        .addImm(1) // InBounds
        .addUse(PtrReg)
        .addUse(Zero.getReg(0))
        .addUse(Idx.getReg(0));

    MachinePointerInfo EltPtrInfo;
    Align EltAlign = Align(1);
    if (!MI.memoperands_empty()) {
      MachineMemOperand *MMO = *MI.memoperands_begin();
      EltPtrInfo =
          MMO->getPointerInfo().getWithOffset(i * EltTy.getSizeInBytes());
      EltAlign = commonAlignment(MMO->getAlign(), i * EltTy.getSizeInBytes());
    }

    MIRBuilder.buildStore(SplitRegs[i], EltPtr, EltPtrInfo, EltAlign);
  }

  MI.eraseFromParent();
  return true;
}

bool SPIRVLegalizerInfo::legalizeCustom(
    LegalizerHelper &Helper, MachineInstr &MI,
    LostDebugLocObserver &LocObserver) const {
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
  switch (MI.getOpcode()) {
  default:
    // TODO: implement legalization for other opcodes.
    return true;
  case TargetOpcode::G_BITCAST:
    return legalizeBitcast(Helper, MI);
  case TargetOpcode::G_EXTRACT_VECTOR_ELT:
    return legalizeExtractVectorElt(Helper, MI, GR);
  case TargetOpcode::G_INSERT_VECTOR_ELT:
    return legalizeInsertVectorElt(Helper, MI, GR);
  case TargetOpcode::G_INTRINSIC:
  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
    return legalizeIntrinsic(Helper, MI);
  case TargetOpcode::G_IS_FPCLASS:
    return legalizeIsFPClass(Helper, MI, LocObserver);
  case TargetOpcode::G_ICMP: {
    assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
    auto &Op0 = MI.getOperand(2);
    auto &Op1 = MI.getOperand(3);
    Register Reg0 = Op0.getReg();
    Register Reg1 = Op1.getReg();
    CmpInst::Predicate Cond =
        static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
    if ((!ST->canDirectlyComparePointers() ||
         (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
        MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
      LLT ConvT = LLT::scalar(ST->getPointerSize());
      Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
                                      ST->getPointerSize());
      SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(
          LLVMTy, Helper.MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
      Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
      Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
    }
    return true;
  }
  case TargetOpcode::G_STORE:
    return legalizeStore(Helper, MI, GR);
  }
}

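// Returns true when a vector type cannot be selected as-is for this subtarget:
// its element count either exceeds the maximum (4 for shaders, 16 otherwise)
// or is a non-power-of-two greater than 4.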
static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
  if (!Ty.isVector())
    return false;
  unsigned NumElements = Ty.getNumElements();
  unsigned MaxVectorSize = ST.isShader() ? 4 : 16;
  return (NumElements > 4 && !isPowerOf2_32(NumElements)) ||
         NumElements > MaxVectorSize;
}

static bool legalizeSpvBitcast(LegalizerHelper &Helper, MachineInstr &MI,
                               SPIRVGlobalRegistry *GR) {
  LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT SrcTy = MRI.getType(SrcReg);

  // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
  // allow using the generic legalization rules.
  if (needsVectorLegalization(DstTy, ST) ||
      needsVectorLegalization(SrcTy, ST)) {
    LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
    MIRBuilder.buildBitcast(DstReg, SrcReg);
    MI.eraseFromParent();
  }
  return true;
}

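// Legalizes an spv_insertelt on a vector type the target cannot select. For a
// constant in-range index the vector is unmerged, the element is replaced, and
// the result is rebuilt with G_BUILD_VECTOR; for a dynamic index the vector is
// spilled to a stack temporary and the element is stored through an spv_gep.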
static bool legalizeSpvInsertElt(LegalizerHelper &Helper, MachineInstr &MI,
                                 SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  Register DstReg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);

  if (needsVectorLegalization(DstTy, ST)) {
    Register SrcReg = MI.getOperand(2).getReg();
    Register ValReg = MI.getOperand(3).getReg();
    LLT SrcTy = MRI.getType(SrcReg);
    MachineOperand &IdxOperand = MI.getOperand(4);

    if (getImm(IdxOperand, &MRI)) {
      uint64_t IdxVal = foldImm(IdxOperand, &MRI);
      if (IdxVal < SrcTy.getNumElements()) {
        SmallVector<Register, 16> Regs;
        SPIRVType *ElementType =
            GR->getScalarOrVectorComponentType(SrcReg);
        LLT ElementLLTTy = GR->getRegType(ElementType);
        for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
          Register Reg = MRI.createGenericVirtualRegister(ElementLLTTy);
          MRI.setRegClass(Reg, GR->getRegClass(ElementType));
          GR->assignSPIRVTypeToVReg(ElementType, Reg, *MI.getMF());
          Regs.push_back(Reg);
        }
        MIRBuilder.buildUnmerge(Regs, SrcReg);
        Regs[IdxVal] = ValReg;
        MIRBuilder.buildBuildVector(DstReg, Regs);
        MI.eraseFromParent();
        return true;
      }
    }

    LLT EltTy = SrcTy.getElementType();
    Align VecAlign = Helper.getStackTemporaryAlignment(SrcTy);

    MachinePointerInfo PtrInfo;
    auto StackTemp = Helper.createStackTemporary(
        TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);

    MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);

    Register IdxReg = IdxOperand.getReg();
    LLT PtrTy = MRI.getType(StackTemp.getReg(0));
    Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
    auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);

    MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
        .addImm(1) // InBounds
        .addUse(StackTemp.getReg(0))
        .addUse(Zero.getReg(0))
        .addUse(IdxReg);

    MachinePointerInfo EltPtrInfo = PtrInfo;
    Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
    MIRBuilder.buildStore(ValReg, EltPtr, EltPtrInfo, EltAlign);

    MIRBuilder.buildLoad(DstReg, StackTemp, PtrInfo, VecAlign);
    MI.eraseFromParent();
    return true;
  }
  return true;
}

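// Legalizes an spv_extractelt on a vector type the target cannot select. For a
// constant in-range index the element is taken straight from a
// G_UNMERGE_VALUES; for a dynamic index the vector is spilled to a stack
// temporary retyped as a pointer to an array of the element type, and the
// element is loaded back through an spv_gep.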
static bool legalizeSpvExtractElt(LegalizerHelper &Helper, MachineInstr &MI,
                                  SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
  const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  Register SrcReg = MI.getOperand(2).getReg();
  LLT SrcTy = MRI.getType(SrcReg);

  if (needsVectorLegalization(SrcTy, ST)) {
    Register DstReg = MI.getOperand(0).getReg();
    MachineOperand &IdxOperand = MI.getOperand(3);

    if (getImm(IdxOperand, &MRI)) {
      uint64_t IdxVal = foldImm(IdxOperand, &MRI);
      if (IdxVal < SrcTy.getNumElements()) {
        LLT DstTy = MRI.getType(DstReg);
        SmallVector<Register, 16> Regs;
        SPIRVType *DstSpvTy = GR->getSPIRVTypeForVReg(DstReg);
        for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
          if (I == IdxVal) {
            Regs.push_back(DstReg);
          } else {
            Register Reg = MRI.createGenericVirtualRegister(DstTy);
            MRI.setRegClass(Reg, GR->getRegClass(DstSpvTy));
            GR->assignSPIRVTypeToVReg(DstSpvTy, Reg, *MI.getMF());
            Regs.push_back(Reg);
          }
        }
        MIRBuilder.buildUnmerge(Regs, SrcReg);
        MI.eraseFromParent();
        return true;
      }
    }

    LLT EltTy = SrcTy.getElementType();
    Align VecAlign = Helper.getStackTemporaryAlignment(SrcTy);

    MachinePointerInfo PtrInfo;
    auto StackTemp = Helper.createStackTemporary(
        TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);

    // Set the type of StackTemp to a pointer to an array of the element type.
    SPIRVType *SpvSrcTy = GR->getSPIRVTypeForVReg(SrcReg);
    SPIRVType *EltSpvTy = GR->getScalarOrVectorComponentType(SpvSrcTy);
    const Type *LLVMEltTy = GR->getTypeForSPIRVType(EltSpvTy);
    const Type *LLVMArrTy =
        ArrayType::get(const_cast<Type *>(LLVMEltTy), SrcTy.getNumElements());
    SPIRVType *ArrSpvTy = GR->getOrCreateSPIRVType(
        LLVMArrTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
    SPIRVType *PtrToArrSpvTy = GR->getOrCreateSPIRVPointerType(
        ArrSpvTy, MIRBuilder, SPIRV::StorageClass::Function);

    Register StackReg = StackTemp.getReg(0);
    MRI.setRegClass(StackReg, GR->getRegClass(PtrToArrSpvTy));
    GR->assignSPIRVTypeToVReg(PtrToArrSpvTy, StackReg, *MI.getMF());

    MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);

    Register IdxReg = IdxOperand.getReg();
    LLT PtrTy = MRI.getType(StackTemp.getReg(0));
    Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
    auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);

    MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
        .addImm(1) // InBounds
        .addUse(StackTemp.getReg(0))
        .addUse(Zero.getReg(0))
        .addUse(IdxReg);

    MachinePointerInfo EltPtrInfo = PtrInfo;
    Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
    MIRBuilder.buildLoad(DstReg, EltPtr, EltPtrInfo, EltAlign);

    MI.eraseFromParent();
    return true;
  }
  return true;
}

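// Custom legalization entry point for calls to SPIR-V intrinsics; intrinsics
// not listed here need no further work.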
bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
                                           MachineInstr &MI) const {
  LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
  auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
  switch (IntrinsicID) {
  case Intrinsic::spv_bitcast:
    return legalizeSpvBitcast(Helper, MI, GR);
  case Intrinsic::spv_insertelt:
    return legalizeSpvInsertElt(Helper, MI, GR);
  case Intrinsic::spv_extractelt:
    return legalizeSpvExtractElt(Helper, MI, GR);
  }
  return true;
}

bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
                                         MachineInstr &MI) const {
  // Once the G_BITCAST is using vectors that are allowed, we turn it back into
  // an spv_bitcast to avoid verifier problems when the register types are the
  // same for the source and the result. Note that the SPIR-V types associated
  // with the bitcast can be different even if the register types are the same.
  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  SmallVector<Register, 1> DstRegs = {DstReg};
  MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg);
  MI.eraseFromParent();
  return true;
}

// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
// to ensure that all instructions created during the lowering have SPIR-V types
// assigned to them.
bool SPIRVLegalizerInfo::legalizeIsFPClass(
    LegalizerHelper &Helper, MachineInstr &MI,
    LostDebugLocObserver &LocObserver) const {
  auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
  FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());

  auto &MIRBuilder = Helper.MIRBuilder;
  auto &MF = MIRBuilder.getMF();
  MachineRegisterInfo &MRI = MF.getRegInfo();

  Type *LLVMDstTy =
      IntegerType::get(MIRBuilder.getContext(), DstTy.getScalarSizeInBits());
  if (DstTy.isVector())
    LLVMDstTy = VectorType::get(LLVMDstTy, DstTy.getElementCount());
  SPIRVType *SPIRVDstTy = GR->getOrCreateSPIRVType(
      LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
      /*EmitIR*/ true);

  unsigned BitSize = SrcTy.getScalarSizeInBits();
  const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());

  LLT IntTy = LLT::scalar(BitSize);
  Type *LLVMIntTy = IntegerType::get(MIRBuilder.getContext(), BitSize);
  if (SrcTy.isVector()) {
    IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
    LLVMIntTy = VectorType::get(LLVMIntTy, SrcTy.getElementCount());
  }
  SPIRVType *SPIRVIntTy = GR->getOrCreateSPIRVType(
      LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
      /*EmitIR*/ true);

  // Clang doesn't support capture of structured bindings:
  LLT DstTyCopy = DstTy;
  const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
    // Assign this MI's (assumed only) destination to one of the two types we
    // expect: either the G_IS_FPCLASS's destination type, or the integer type
    // bitcast from the source type.
    LLT MITy = MRI.getType(MI.getReg(0));
    assert((MITy == IntTy || MITy == DstTyCopy) &&
           "Unexpected LLT type while lowering G_IS_FPCLASS");
    auto *SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
    GR->assignSPIRVTypeToVReg(SPVTy, MI.getReg(0), MF);
    return MI;
  };

  // Helper to build and assign a constant in one go
  const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder {
    if (!Ty.isFixedVector())
      return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C));
    auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C);
    assert((Ty == IntTy || Ty == DstTyCopy) &&
           "Unexpected LLT type while lowering constant for G_IS_FPCLASS");
    SPIRVType *VecEltTy = GR->getOrCreateSPIRVType(
        (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder,
        SPIRV::AccessQualifier::ReadWrite,
        /*EmitIR*/ true);
    GR->assignSPIRVTypeToVReg(VecEltTy, ScalarC.getReg(0), MF);
    return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Ty, ScalarC));
  };

  if (Mask == fcNone) {
    MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 0));
    MI.eraseFromParent();
    return true;
  }
  if (Mask == fcAllFlags) {
    MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 1));
    MI.eraseFromParent();
    return true;
  }

  // Note that rather than creating a COPY here (between a floating-point and
  // integer type of the same size) we create a SPIR-V bitcast immediately. We
  // can't create a G_BITCAST because the LLTs are the same, and we can't seem
  // to correctly lower COPYs to SPIR-V bitcasts at this moment.
  Register ResVReg = MRI.createGenericVirtualRegister(IntTy);
  MRI.setRegClass(ResVReg, GR->getRegClass(SPIRVIntTy));
  GR->assignSPIRVTypeToVReg(SPIRVIntTy, ResVReg, Helper.MIRBuilder.getMF());
  auto AsInt = MIRBuilder.buildInstr(SPIRV::OpBitcast)
                   .addDef(ResVReg)
                   .addUse(GR->getSPIRVTypeID(SPIRVIntTy))
                   .addUse(SrcReg);
  AsInt = assignSPIRVTy(std::move(AsInt));

  // Various masks.
  APInt SignBit = APInt::getSignMask(BitSize);
  APInt ValueMask = APInt::getSignedMaxValue(BitSize);     // All bits but sign.
  APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
  APInt ExpMask = Inf;
  APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
  APInt QNaNBitMask =
      APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
  APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());

  auto SignBitC = buildSPIRVConstant(IntTy, SignBit);
  auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask);
  auto InfC = buildSPIRVConstant(IntTy, Inf);
  auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask);
  auto ZeroC = buildSPIRVConstant(IntTy, 0);

  auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC));
  auto Sign = assignSPIRVTy(
      MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs));

  auto Res = buildSPIRVConstant(DstTy, 0);

  const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
    Res = assignSPIRVTy(
        MIRBuilder.buildOr(DstTyCopy, Res, assignSPIRVTy(std::move(ToAppend))));
  };

  // Tests that involve more than one class should be processed first.
  if ((Mask & fcFinite) == fcFinite) {
    // finite(V) ==> abs(V) u< exp_mask
    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
                                     ExpMaskC));
    Mask &= ~fcFinite;
  } else if ((Mask & fcFinite) == fcPosFinite) {
    // finite(V) && V > 0 ==> V u< exp_mask
    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
                                     ExpMaskC));
    Mask &= ~fcPosFinite;
  } else if ((Mask & fcFinite) == fcNegFinite) {
    // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
    auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT,
                                                  DstTy, Abs, ExpMaskC));
    appendToRes(MIRBuilder.buildAnd(DstTy, Cmp, Sign));
    Mask &= ~fcNegFinite;
  }

  if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
    // fcZero | fcSubnormal => test all exponent bits are 0
    // TODO: Handle sign bit specific cases
    // TODO: Handle inverted case
    if (PartialCheck == (fcZero | fcSubnormal)) {
      auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC));
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       ExpBits, ZeroC));
      Mask &= ~PartialCheck;
    }
  }

  // Check for individual classes.
  if (FPClassTest PartialCheck = Mask & fcZero) {
    if (PartialCheck == fcPosZero)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, ZeroC));
    else if (PartialCheck == fcZero)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
    else // fcNegZero
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, SignBitC));
  }

  if (FPClassTest PartialCheck = Mask & fcSubnormal) {
    // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
    // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
    auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
    auto OneC = buildSPIRVConstant(IntTy, 1);
    auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
    auto SubnormalRes = assignSPIRVTy(
        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
                             buildSPIRVConstant(IntTy, AllOneMantissa)));
    if (PartialCheck == fcNegSubnormal)
      SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
    appendToRes(std::move(SubnormalRes));
  }

  if (FPClassTest PartialCheck = Mask & fcInf) {
    if (PartialCheck == fcPosInf)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, InfC));
    else if (PartialCheck == fcInf)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
    else { // fcNegInf
      APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
      auto NegInfC = buildSPIRVConstant(IntTy, NegInf);
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, NegInfC));
    }
  }

  if (FPClassTest PartialCheck = Mask & fcNan) {
    auto InfWithQnanBitC =
        buildSPIRVConstant(IntTy, std::move(Inf) | QNaNBitMask);
    if (PartialCheck == fcNan) {
      // isnan(V) ==> abs(V) u> int(inf)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
    } else if (PartialCheck == fcQNan) {
      // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
                                       InfWithQnanBitC));
    } else { // fcSNan
      // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
      //                    abs(V) u< (unsigned(Inf) | quiet_bit)
      auto IsNan = assignSPIRVTy(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
      auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp(
          CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC));
      appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
    }
  }

  if (FPClassTest PartialCheck = Mask & fcNormal) {
    // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
    // (max_exp-1))
    APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
    auto ExpMinusOne = assignSPIRVTy(
        MIRBuilder.buildSub(IntTy, Abs, buildSPIRVConstant(IntTy, ExpLSB)));
    APInt MaxExpMinusOne = std::move(ExpMask) - ExpLSB;
    auto NormalRes = assignSPIRVTy(
        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
                             buildSPIRVConstant(IntTy, MaxExpMinusOne)));
    if (PartialCheck == fcNegNormal)
      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
    else if (PartialCheck == fcPosNormal) {
      auto PosSign = assignSPIRVTy(MIRBuilder.buildXor(
          DstTy, Sign, buildSPIRVConstant(DstTy, InversionMask)));
      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
    }
    appendToRes(std::move(NormalRes));
  }

  MIRBuilder.buildCopy(DstReg, Res);
  MI.eraseFromParent();
  return true;
}