LLVM 22.0.0git
SPIRVISelLowering.cpp
Go to the documentation of this file.
1//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIRVTargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVISelLowering.h"
14#include "SPIRV.h"
15#include "SPIRVInstrInfo.h"
17#include "SPIRVRegisterInfo.h"
18#include "SPIRVSubtarget.h"
21#include "llvm/IR/IntrinsicsSPIRV.h"
22
23#define DEBUG_TYPE "spirv-lower"
24
25using namespace llvm;
26
30
31// Returns true if the types logically match, as defined in
32// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
33static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2,
35 if (Ty1->getOpcode() != Ty2->getOpcode())
36 return false;
37
// Both types must also have the same operand count (for structs, the same
// number of members, since members start at operand 1 below).
38 if (Ty1->getNumOperands() != Ty2->getNumOperands())
39 return false;
40
// Arrays: operand 2 holds the length and operand 1 the element type.
// NOTE(review): the length check compares registers, not constant values;
// presumably equal lengths share one deduplicated constant -- confirm.
41 if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
42 // Array must have the same size.
43 if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg())
44 return false;
45
// NOTE(review): getSPIRVTypeForVReg can return null (other code in this
// file null-checks its result); a null element type would be dereferenced
// by the recursive call below.
46 SPIRVType *ElemType1 = GR.getSPIRVTypeForVReg(Ty1->getOperand(1).getReg());
47 SPIRVType *ElemType2 = GR.getSPIRVTypeForVReg(Ty2->getOperand(1).getReg());
48 return ElemType1 == ElemType2 ||
49 typesLogicallyMatch(ElemType1, ElemType2, GR);
50 }
51
// Structs: each pair of member types (operands 1..N-1) must be identical
// or logically match.
52 if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
53 for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
54 SPIRVType *ElemType1 =
56 SPIRVType *ElemType2 =
58 if (ElemType1 != ElemType2 &&
59 !typesLogicallyMatch(ElemType1, ElemType2, GR))
60 return false;
61 }
62 return true;
63 }
// Any other kind of type is reported as not matching; callers in this file
// test for identical types (==) before calling.
64 return false;
65}
66
68 LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
// Number of registers used to pass a value of type VT; CC is not
// consulted by this implementation.
69 // This code avoids CallLowering fail inside getVectorTypeBreakdown
70 // on v3i1 arguments. Maybe we need to return 1 for all types.
71 // TODO: remove it once this case is supported by the default implementation.
// Note: v3i8 is special-cased here as well as v3i1.
72 if (VT.isVector() && VT.getVectorNumElements() == 3 &&
73 (VT.getVectorElementType() == MVT::i1 ||
74 VT.getVectorElementType() == MVT::i8))
75 return 1;
// Scalar integers of up to 64 bits use a single register.
76 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
77 return 1;
// Everything else uses the generic breakdown.
78 return getNumRegisters(Context, VT);
79}
80
83 EVT VT) const {
// Register type used to pass a value of type VT.
84 // This code avoids CallLowering fail inside getVectorTypeBreakdown
85 // on v3i1 arguments. Maybe we need to return i32 for all types.
86 // TODO: remove it once this case is supported by the default implementation.
// v3i1 and v3i8 are widened to v4i1 and v4i8; getNumRegistersForCallingConv
// reports a single register for these same two types. Other 3-element
// vectors use the generic path.
87 if (VT.isVector() && VT.getVectorNumElements() == 3) {
88 if (VT.getVectorElementType() == MVT::i1)
89 return MVT::v4i1;
90 else if (VT.getVectorElementType() == MVT::i8)
91 return MVT::v4i8;
92 }
93 return getRegisterType(Context, VT);
94}
95
97 const CallInst &I,
99 unsigned Intrinsic) const {
// Fills Info for the spv_load and spv_store intrinsics and returns true;
// returns false for every other intrinsic. Per the indices used below, the
// alignment argument is at index 2 for spv_load and 3 for spv_store, and
// the memory-operand flags sit immediately before it.
100 unsigned AlignIdx = 3;
101 switch (Intrinsic) {
102 case Intrinsic::spv_load:
103 AlignIdx = 2;
104 [[fallthrough]];
105 case Intrinsic::spv_store: {
// NOTE(review): for a CallInst, getNumOperands() also counts the callee
// operand, so this check can pass even when no alignment argument is
// present; arg_size() would be the exact argument count -- confirm.
106 if (I.getNumOperands() >= AlignIdx + 1) {
107 auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
108 Info.align = Align(AlignOp->getZExtValue());
109 }
110 Info.flags = static_cast<MachineMemOperand::Flags>(
111 cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
// The memory type is always reported as i64, whatever is accessed.
112 Info.memVT = MVT::i64;
113 // TODO: take into account opaque pointers (don't use getElementType).
114 // MVT::getVT(PtrTy->getElementType());
115 return true;
// NOTE(review): unreachable after the return above.
116 break;
117 }
118 default:
119 break;
120 }
121 return false;
122}
123
124std::pair<unsigned, const TargetRegisterClass *>
126 StringRef Constraint,
127 MVT VT) const {
// Chooses a register class for an inline-asm operand from VT alone; the
// returned register number is always 0.
128 const TargetRegisterClass *RC = nullptr;
// Constraints starting with '{' (presumably explicit register names) get
// a null register class.
129 if (Constraint.starts_with("{"))
130 return std::make_pair(0u, RC);
131
// Float and integer values get scalar or vector ID classes; any other
// value type falls back to the scalar integer ID class.
132 if (VT.isFloatingPoint())
133 RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
134 else if (VT.isInteger())
135 RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
136 else
137 RC = &SPIRV::iIDRegClass;
138
139 return std::make_pair(0u, RC);
140}
141
// Returns the register whose SPIR-V type callers look up for OpReg: when
// OpReg is defined by an OpFunctionParameter, that instruction's operand 1
// (its type register); otherwise OpReg itself.
143 SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
// Note: despite its name, TypeInst is the instruction that defines OpReg.
144 return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
145 ? TypeInst->getOperand(1).getReg()
146 : OpReg;
147}
148
151 Register OpReg, unsigned OpIdx,
152 SPIRVType *NewPtrType) {
// Builds `NewReg = OpBitcast <NewPtrType> OpReg` at I's position (through a
// MachineIRBuilder created on I), then makes operand OpIdx of I use NewReg
// instead of OpReg. Aborts via report_fatal_error if the new instruction's
// operands cannot be constrained.
153 MachineIRBuilder MIB(I);
154 Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
155 bool Res = MIB.buildInstr(SPIRV::OpBitcast)
156 .addDef(NewReg)
157 .addUse(GR.getSPIRVTypeID(NewPtrType))
158 .addUse(OpReg)
160 *STI.getRegBankInfo());
161 if (!Res)
162 report_fatal_error("insert validation bitcast: cannot constrain all uses");
163 I.getOperand(OpIdx).setReg(NewReg);
164}
165
167 SPIRVType *OpType, bool ReuseType,
168 SPIRVType *ResType, const Type *ResTy) {
// Returns a pointer type in OpType's storage class (read from operand 1 of
// OpType). The pointee is ResType when ReuseType is true; otherwise it is a
// SPIR-V type obtained from the LLVM type ResTy.
169 SPIRV::StorageClass::StorageClass SC =
170 static_cast<SPIRV::StorageClass::StorageClass>(
171 OpType->getOperand(1).getImm());
172 MachineIRBuilder MIB(I);
173 SPIRVType *NewBaseType =
174 ReuseType ? ResType
176 ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
177 return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
178}
179
180// Insert a bitcast before the instruction to keep SPIR-V code valid
181// when there is a type mismatch between results and operand types.
182static void validatePtrTypes(const SPIRVSubtarget &STI,
184 MachineInstr &I, unsigned OpIdx,
185 SPIRVType *ResType, const Type *ResTy = nullptr) {
// Operand OpIdx of I is expected to be a pointer to ResType. ResTy is the
// expected pointee as an LLVM type and is only consulted when ResType
// belongs to a different MachineFunction than I. Does nothing if either
// type is unknown or the operand is not a pointer.
186 // Get operand type
187 MachineFunction *MF = I.getParent()->getParent();
188 Register OpReg = I.getOperand(OpIdx).getReg();
189 Register OpTypeReg = getTypeReg(MRI, OpReg);
190 SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
191 if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
192 return;
193 // Get operand's pointee type
194 Register ElemTypeReg = OpType->getOperand(2).getReg();
195 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
196 if (!ElemType)
197 return;
198 // Check if we need a bitcast to make a statement valid
// SPIR-V type instructions are compared by identity only within a single
// function; across functions the LLVM types are compared instead.
// NOTE(review): ResTy defaults to nullptr, so callers passing a ResType from
// another function must supply ResTy (validateFunCallMachineDef does).
199 bool IsSameMF = MF == ResType->getParent()->getParent();
200 bool IsEqualTypes = IsSameMF ? ElemType == ResType
201 : GR.getTypeForSPIRVType(ElemType) == ResTy;
202 if (IsEqualTypes)
203 return;
204 // There is a type mismatch between results and operand types
205 // and we insert a bitcast before the instruction to keep SPIR-V code valid
206 SPIRVType *NewPtrType =
207 createNewPtrType(GR, I, OpType, IsSameMF, ResType, ResTy);
// Pointer types that are not bitcast-compatible are reported as an error.
208 if (!GR.isBitcastCompatible(NewPtrType, OpType))
210 "insert validation bitcast: incompatible result and operand types");
211 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
212}
213
214// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
215// that doesn't point to OpTypeEvent.
219 MachineInstr &I) {
// The pointer checked here is operand 2 of OpGroupWaitEvents.
220 constexpr unsigned OpIdx = 2;
221 MachineFunction *MF = I.getParent()->getParent();
222 Register OpReg = I.getOperand(OpIdx).getReg();
223 Register OpTypeReg = getTypeReg(MRI, OpReg);
224 SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
225 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
226 return;
// Nothing to do when the pointee is unknown or already OpTypeEvent.
227 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
228 if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
229 return;
230 // Insert a bitcast before the instruction to keep SPIR-V code valid.
// The replacement pointer keeps OpType's storage class and points to the
// "spirv.Event" target extension type.
231 LLVMContext &Context = MF->getFunction().getContext();
232 SPIRVType *NewPtrType =
233 createNewPtrType(GR, I, OpType, false, nullptr,
234 TargetExtType::get(Context, "spirv.Event"));
235 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
236}
237
// Lifetime markers: operand 0 is accepted as-is when its pointee is unknown,
// void, or an 8-bit integer. Any other pointee type gets a validation
// bitcast to a new pointer type in the same storage class.
241 Register PtrReg = I.getOperand(0).getReg();
242 MachineFunction *MF = I.getParent()->getParent();
243 Register PtrTypeReg = getTypeReg(MRI, PtrReg);
244 SPIRVType *PtrType = GR.getSPIRVTypeForVReg(PtrTypeReg, MF);
245 SPIRVType *PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
246 if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||
247 (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&
248 PonteeElemType->getOperand(1).getImm() == 8))
249 return;
250 // To keep the code valid a bitcast must be inserted
251 SPIRV::StorageClass::StorageClass SC =
252 static_cast<SPIRV::StorageClass::StorageClass>(
253 PtrType->getOperand(1).getImm());
254 MachineIRBuilder MIB(I);
255 LLVMContext &Context = MF->getFunction().getContext();
256 SPIRVType *NewPtrType =
258 doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
259}
260
264 MachineInstr &I, unsigned OpIdx) {
// If operand OpIdx points to a struct with exactly one member whose type is
// a vector, integer, float or bool, bitcast it to a pointer to that member
// type in the same storage class. Every other case is left unchanged.
265 MachineFunction *MF = I.getParent()->getParent();
266 Register OpReg = I.getOperand(OpIdx).getReg();
267 Register OpTypeReg = getTypeReg(MRI, OpReg);
268 SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
269 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
270 return;
// A struct with two operands (result + one member) is a single-field wrapper.
271 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
272 if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
273 ElemType->getNumOperands() != 2)
274 return;
275 // It's a structure-wrapper around another type with a single member field.
276 SPIRVType *MemberType =
277 GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
278 if (!MemberType)
279 return;
280 unsigned MemberTypeOp = MemberType->getOpcode();
281 if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
282 MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
283 return;
284 // It's a structure-wrapper around a valid type. Insert a bitcast before the
285 // instruction to keep SPIR-V code valid.
286 SPIRV::StorageClass::StorageClass SC =
287 static_cast<SPIRV::StorageClass::StorageClass>(
288 OpType->getOperand(1).getImm());
289 MachineIRBuilder MIB(I);
290 SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
291 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
292}
293
294// Insert a bitcast before the function call instruction to keep SPIR-V code
295// valid when there is a type mismatch between actual and expected types of an
296// argument:
297// %formal = OpFunctionParameter %formal_type
298// ...
299// %res = OpFunctionCall %ty %fun %actual ...
300// implies that %actual is of %formal_type; in the case of opaque pointers,
301// we may need to insert a bitcast to ensure this.
303 MachineRegisterInfo *DefMRI,
304 MachineRegisterInfo *CallMRI,
305 SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
306 MachineInstr *FunDef) {
// DefMRI belongs to the callee's function, CallMRI to the caller's.
307 if (FunDef->getOpcode() != SPIRV::OpFunction)
308 return;
// Walk the OpFunctionParameter instructions that follow the OpFunction in
// lockstep with the call's operands, starting at operand 3 (operand 2 of
// OpFunctionCall is the callee).
309 unsigned OpIdx = 3;
310 for (FunDef = FunDef->getNextNode();
311 FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
312 OpIdx < FunCall.getNumOperands();
313 FunDef = FunDef->getNextNode(), OpIdx++) {
// Only pointer-typed formal parameters are checked.
314 SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
315 SPIRVType *DefElemType =
316 DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
317 ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
318 DefPtrType->getParent()->getParent())
319 : nullptr;
320 if (DefElemType) {
321 const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
322 // validatePtrTypes() works in the context of the call site.
323 // When we process historical records about forward calls
324 // we need to switch context to the (forward) call site and
325 // then restore it back to the current machine function.
326 MachineFunction *CurMF =
327 GR.setCurrentFunc(*FunCall.getParent()->getParent());
328 validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
329 DefElemTy);
330 GR.setCurrentFunc(*CurMF);
331 }
332 }
333}
334
335// Ensure there is no mismatch between actual and expected arg types: calls
336// with a processed definition. Return Function pointer if it's a forward
337// call (ahead of definition), and nullptr otherwise.
339 MachineRegisterInfo *CallMRI,
341 MachineInstr &FunCall) {
// Operand 2 of the OpFunctionCall is the callee. If the registry already has
// a definition for it, the call's arguments are checked now and nullptr is
// returned; otherwise the callee is returned so the call can be recorded as
// a forward call.
// NOTE(review): dyn_cast yields nullptr when the callee global is not a
// Function (e.g. an alias); that null is passed to getFunctionDefinition.
342 const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
343 const Function *F = dyn_cast<Function>(GV);
344 MachineInstr *FunDef =
345 const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
346 if (!FunDef)
347 return F;
348 MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
349 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
350 return nullptr;
351}
352
353// Ensure there is no mismatch between actual and expected arg types: calls
354// ahead of a processed definition.
357 MachineInstr &FunDef) {
// Run for each OpFunction: re-checks calls to this function that were
// recorded (via addForwardCall) before its definition was processed.
358 const Function *F = GR.getFunctionByDefinition(&FunDef);
360 for (MachineInstr *FunCall : *FwdCalls) {
// A recorded call site may live in another function; use that function's
// register info.
361 MachineRegisterInfo *CallMRI =
362 &FunCall->getParent()->getParent()->getRegInfo();
363 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
364 }
365}
366
367// Validation of an access chain.
// Checks that operand 2 of I points to the pointee type of I's result
// (operand 0), inserting a bitcast via validatePtrTypes if it does not.
// Used for pointer access chains and generic-pointer casts.
370 SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
371 if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
372 SPIRVType *BaseElemType =
373 GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
374 validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
375 }
376}
377
378// TODO: the logic of inserting additional bitcasts is to be moved
379// to pre-IRTranslation passes eventually
// Post-selection fixups for one MachineFunction:
// - inserts validation bitcasts where pointer operands disagree with the
//   types an instruction implies;
// - checks call arguments against callee parameters, including forward calls;
// - rewrites add/sub/bitwise ops on booleans into logical ops;
// - turns a zero OpConstantI of a non-integer type into OpConstantNull;
// - moves type definitions used by OpPhi into a predecessor block.
381 // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
382 // We'd like to avoid the needless second processing pass.
383 if (ProcessedMF.find(&MF) != ProcessedMF.end())
384 return;
385
387 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
388 GR.setCurrentFunc(MF);
389 for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
392 for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
393 MBBI != MBBE;) {
394 MachineInstr &MI = *MBBI++;
// MBBI already points past MI here, so instructions the helpers below
// insert immediately before or after MI are not visited by this loop.
395 switch (MI.getOpcode()) {
396 case SPIRV::OpAtomicLoad:
397 case SPIRV::OpAtomicExchange:
398 case SPIRV::OpAtomicCompareExchange:
399 case SPIRV::OpAtomicCompareExchangeWeak:
400 case SPIRV::OpAtomicIIncrement:
401 case SPIRV::OpAtomicIDecrement:
402 case SPIRV::OpAtomicIAdd:
403 case SPIRV::OpAtomicISub:
404 case SPIRV::OpAtomicSMin:
405 case SPIRV::OpAtomicUMin:
406 case SPIRV::OpAtomicSMax:
407 case SPIRV::OpAtomicUMax:
408 case SPIRV::OpAtomicAnd:
409 case SPIRV::OpAtomicOr:
410 case SPIRV::OpAtomicXor:
411 // for the above listed instructions
412 // OpAtomicXXX <ResType>, ptr %Op, ...
413 // implies that %Op is a pointer to <ResType>
414 case SPIRV::OpLoad:
415 // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
417 break;
418
419 validatePtrTypes(STI, MRI, GR, MI, 2,
420 GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
421 break;
422 case SPIRV::OpAtomicStore:
423 // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
424 // implies that %Op points to the <Obj>'s type
425 validatePtrTypes(STI, MRI, GR, MI, 0,
426 GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
427 break;
428 case SPIRV::OpStore:
429 // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
430 validatePtrTypes(STI, MRI, GR, MI, 0,
431 GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
432 break;
// Pointer casts: operand 2 must point to the result's pointee type.
433 case SPIRV::OpPtrCastToGeneric:
434 case SPIRV::OpGenericCastToPtr:
435 case SPIRV::OpGenericCastToPtrExplicit:
436 validateAccessChain(STI, MRI, GR, MI);
437 break;
// Access chains: only the 4-operand form is validated (presumably the
// form without trailing indices).
438 case SPIRV::OpPtrAccessChain:
439 case SPIRV::OpInBoundsPtrAccessChain:
440 if (MI.getNumOperands() == 4)
441 validateAccessChain(STI, MRI, GR, MI);
442 break;
443
444 case SPIRV::OpFunctionCall:
445 // ensure there is no mismatch between actual and expected arg types:
446 // calls with a processed definition
// A non-null result means the callee's definition is not known yet; the
// call is recorded so validateForwardCalls can check it later.
447 if (MI.getNumOperands() > 3)
448 if (const Function *F = validateFunCall(STI, MRI, GR, MI))
449 GR.addForwardCall(F, &MI);
450 break;
451 case SPIRV::OpFunction:
452 // ensure there is no mismatch between actual and expected arg types:
453 // calls ahead of a processed definition
454 validateForwardCalls(STI, MRI, GR, MI);
455 break;
456
457 // ensure that LLVM IR add/sub instructions result in logical SPIR-V
458 // instructions when applied to bool type
// (On booleans, add and sub both reduce to XOR, i.e. OpLogicalNotEqual.)
459 case SPIRV::OpIAddS:
460 case SPIRV::OpIAddV:
461 case SPIRV::OpISubS:
462 case SPIRV::OpISubV:
463 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
464 SPIRV::OpTypeBool))
465 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
466 break;
467
468 // ensure that LLVM IR bitwise instructions result in logical SPIR-V
469 // instructions when applied to bool type
470 case SPIRV::OpBitwiseOrS:
471 case SPIRV::OpBitwiseOrV:
472 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
473 SPIRV::OpTypeBool))
474 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
475 break;
476 case SPIRV::OpBitwiseAndS:
477 case SPIRV::OpBitwiseAndV:
478 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
479 SPIRV::OpTypeBool))
480 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
481 break;
482 case SPIRV::OpBitwiseXorS:
483 case SPIRV::OpBitwiseXorV:
484 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
485 SPIRV::OpTypeBool))
486 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
487 break;
488 case SPIRV::OpLifetimeStart:
489 case SPIRV::OpLifetimeStop:
490 if (MI.getOperand(1).getImm() > 0)
491 validateLifetimeStart(STI, MRI, GR, MI);
492 break;
493 case SPIRV::OpGroupAsyncCopy:
494 validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
495 validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
496 break;
497 case SPIRV::OpGroupWaitEvents:
498 // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
500 break;
501 case SPIRV::OpConstantI: {
502 SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
// NOTE(review): Type is dereferenced below without a null check.
503 if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
504 MI.getOperand(2).getImm() == 0) {
505 // Validate the null constant of a target extension type
// Only operands 0 and 1 are kept for the OpConstantNull.
506 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
507 for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
508 MI.removeOperand(i);
509 }
510 } break;
511 case SPIRV::OpPhi: {
512 // Phi refers to a type definition that goes after the Phi
513 // instruction, so that the virtual register definition of the type
514 // doesn't dominate all uses. Let's place the type definition
515 // instruction at the end of the predecessor.
516 MachineBasicBlock *Curr = MI.getParent();
517 SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
// NOTE(review): Type is dereferenced below without a null check.
518 if (Type->getParent() == Curr && !Curr->pred_empty())
519 ToMove.insert(const_cast<MachineInstr *>(Type));
520 } break;
521 case SPIRV::OpExtInst: {
522 // prefetch
// Only OpenCL.std extended instructions whose set and opcode operands
// are immediates are handled; `continue` moves to the next instruction.
523 if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
524 MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
525 continue;
526 switch (MI.getOperand(3).getImm()) {
527 case SPIRV::OpenCLExtInst::frexp:
528 case SPIRV::OpenCLExtInst::lgamma_r:
529 case SPIRV::OpenCLExtInst::remquo: {
530 // The last operand must be of a pointer to i32 or vector of i32
531 // values.
532 MachineIRBuilder MIB(MI);
533 SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
534 SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
535 assert(RetType && "Expected return type");
536 validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1,
537 RetType->getOpcode() != SPIRV::OpTypeVector
538 ? Int32Type
540 Int32Type, RetType->getOperand(2).getImm(),
541 MIB, false));
542 } break;
543 case SPIRV::OpenCLExtInst::fract:
544 case SPIRV::OpenCLExtInst::modf:
545 case SPIRV::OpenCLExtInst::sincos:
546 // The last operand must be of a pointer to the base type represented
547 // by the previous operand.
548 assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
549 "Expected v-reg");
551 STI, MRI, GR, MI, MI.getNumOperands() - 1,
553 MI.getOperand(MI.getNumOperands() - 2).getReg()));
554 break;
555 case SPIRV::OpenCLExtInst::prefetch:
556 // Expected `ptr` type is a pointer to float, integer or vector, but
557 // the pointee value can be wrapped into a struct.
558 assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
559 "Expected v-reg");
561 MI.getNumOperands() - 2);
562 break;
563 }
564 } break;
565 }
566 }
// Move each collected type definition to the end of its block's first
// predecessor, just before that predecessor's terminator.
// NOTE(review): only the first predecessor is used; with several
// predecessors the moved definition may still not dominate every use.
567 for (MachineInstr *MI : ToMove) {
568 MachineBasicBlock *Curr = MI->getParent();
569 MachineBasicBlock *Pred = *Curr->pred_begin();
570 Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));
571 }
572 }
573 ProcessedMF.insert(&MF);
575}
576
577// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
578// PtrOpIdx matches the type for operand OpIdx. Returns true if they already
579// match or if the instruction was modified to make them match.
581 MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
582 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
583 SPIRVType *PtrType = GR.getResultType(I.getOperand(PtrOpIdx).getReg());
584 SPIRVType *PointeeType = GR.getPointeeType(PtrType);
585 SPIRVType *OpType = GR.getResultType(I.getOperand(OpIdx).getReg());
586
// Identical SPIR-V type instructions: nothing to do.
587 if (PointeeType == OpType)
588 return true;
589
// Logically matching types: only the case where operand OpIdx is a result
// (def) of I is supported, by retyping that result and adding an
// OpCopyLogical back to the old register.
590 if (typesLogicallyMatch(PointeeType, OpType, GR)) {
591 // Apply OpCopyLogical to OpIdx.
592 if (I.getOperand(OpIdx).isDef() &&
593 insertLogicalCopyOnResult(I, PointeeType)) {
594 return true;
595 }
596
// NOTE(review): llvm_unreachable marks this path as impossible; if it is
// reached in a release build the behavior may be undefined, so the
// `return false` below should not be relied on.
597 llvm_unreachable("Unable to add OpCopyLogical yet.");
598 return false;
599 }
600
// Types differ and do not logically match: I is left unchanged.
601 return false;
602}
603
605 MachineInstr &I, SPIRVType *NewResultType) const {
// Gives I's single result a fresh register of type NewResultType, then
// emits `OldResultReg = OpCopyLogical OldTypeReg NewResultReg` right after
// I, so existing users of the old result register keep its old type.
// Returns whether the new instruction's operands could be constrained.
606 MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
607 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
608
609 Register NewResultReg =
610 createVirtualRegister(NewResultType, &GR, MRI, *I.getMF());
611 Register NewTypeReg = GR.getSPIRVTypeID(NewResultType);
612
613 assert(llvm::size(I.defs()) == 1 && "Expected only one def");
614 MachineOperand &OldResult = *I.defs().begin();
615 Register OldResultReg = OldResult.getReg();
// Assumes the first use operand of I is its result-type register.
616 MachineOperand &OldType = *I.uses().begin();
617 Register OldTypeReg = OldType.getReg();
618
619 OldResult.setReg(NewResultReg);
620 OldType.setReg(NewTypeReg);
621
// NOTE(review): getNextNode() returns nullptr for the last instruction of
// a block; this dereference assumes I is never last.
622 MachineIRBuilder MIB(*I.getNextNode());
623 return MIB.buildInstr(SPIRV::OpCopyLogical)
624 .addDef(OldResultReg)
625 .addUse(OldTypeReg)
626 .addUse(NewResultReg)
627 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
628 *STI.getRegBankInfo());
629}
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock & MBB
MachineBasicBlock MachineBasicBlock::iterator MBBI
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Register const TargetRegisterInfo * TRI
MachineInstr unsigned OpIdx
static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, Register OpReg, unsigned OpIdx, SPIRVType *NewPtrType)
static SPIRVType * createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I, SPIRVType *OpType, bool ReuseType, SPIRVType *ResType, const Type *ResTy)
static void validateLifetimeStart(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx)
Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg)
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
void validateFunCallMachineDef(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall, MachineInstr *FunDef)
void validateForwardCalls(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunDef)
const Function * validateFunCall(const SPIRVSubtarget &STI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall)
static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2, SPIRVGlobalRegistry &GR)
static void validatePtrTypes(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx, SPIRVType *ResType, const Type *ResTy=nullptr)
This class represents a function call, abstracting a target machine's calling convention.
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:359
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
Machine Value Type.
bool isVector() const
Return true if this is a vector value type.
bool isInteger() const
Return true if this is an integer or a vector integer type.
bool isFloatingPoint() const
Return true if this is a FP or a vector FP type.
LLVM_ABI MachineInstr * remove_instr(MachineInstr *I)
Remove the possibly bundled instruction from the instruction list without deleting it.
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineInstrBundleIterator< MachineInstr > iterator
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
BasicBlockListType::iterator iterator
void insert(iterator MBBI, MachineBasicBlock *MBB)
Helper class to build MachineInstr.
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
bool constrainAllUses(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const RegisterBankInfo &RBI) const
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Returns the total number of operands.
const MachineOperand & getOperand(unsigned i) const
Flags
Flags values. These may be or'd together.
MachineOperand class - Representation of each machine instruction operand.
const GlobalValue * getGlobal() const
int64_t getImm() const
LLVM_ABI void setReg(Register Reg)
Change the register this operand corresponds to.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
LLVM_ABI MachineInstr * getVRegDef(Register Reg) const
getVRegDef - Return the machine instr that defines the specified virtual register or null if none is ...
Wrapper class representing virtual and physical registers.
Definition Register.h:20
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
void addForwardCall(const Function *F, MachineInstr *MI)
SPIRVType * getResultType(Register VReg, MachineFunction *MF=nullptr)
const Type * getTypeForSPIRVType(const SPIRVType *Ty) const
bool isBitcastCompatible(const SPIRVType *Type1, const SPIRVType *Type2) const
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineInstr &I, SPIRV::AccessQualifier::AccessQualifier AQ, bool EmitIR)
SPIRVType * getOrCreateSPIRVPointerType(const Type *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SC)
const MachineInstr * getFunctionDefinition(const Function *F)
SPIRVType * getPointeeType(SPIRVType *PtrType)
Register getSPIRVTypeID(const SPIRVType *SpirvType) const
SmallPtrSet< MachineInstr *, 8 > * getForwardCalls(const Function *F)
SPIRVType * getOrCreateSPIRVVectorType(SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder, bool EmitIR)
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const
MachineFunction * setCurrentFunc(MachineFunction &MF)
SPIRVType * getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineIRBuilder &MIRBuilder)
const Function * getFunctionByDefinition(const MachineInstr *MI)
const SPIRVInstrInfo * getInstrInfo() const override
const SPIRVRegisterInfo * getRegisterInfo() const override
const RegisterBankInfo * getRegBankInfo() const override
bool enforcePtrTypeCompatibility(MachineInstr &I, unsigned PtrOpIdx, unsigned OpIdx) const
unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const override
Return the number of registers that this ValueType will eventually require.
unsigned getNumRegistersForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain targets require unusual breakdowns of certain types.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain combinations of ABIs, Targets and features require that types are legal for some operations a...
bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
void finalizeLowering(MachineFunction &MF) const override
Execute target specific actions to finalize target lowering.
bool insertLogicalCopyOnResult(MachineInstr &I, SPIRVType *NewResultType) const
SPIRVTargetLowering(const TargetMachine &TM, const SPIRVSubtarget &ST)
std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const override
Given a physical register constraint (e.g.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition StringRef.h:261
static LLVM_ABI TargetExtType * get(LLVMContext &Context, StringRef Name, ArrayRef< Type * > Types={}, ArrayRef< unsigned > Ints={})
Return a target extension type having the specified name and optional type and integer parameters.
Definition Type.cpp:907
virtual void finalizeLowering(MachineFunction &MF) const
Execute target specific actions to finalize target lowering.
MVT getRegisterType(MVT VT) const
Return the type of registers that this ValueType will eventually require.
TargetLowering(const TargetLowering &)=delete
Primary interface to the complete machine description for the target machine.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:294
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
This namespace contains an enum with a value for every intrinsic/builtin function known by LLVM.
This is an optimization pass for GlobalISel generic memory operations.
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1655
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI, const MachineFunction &MF)
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:167
const MachineInstr SPIRVType
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
Extended Value Type.
Definition ValueTypes.h:35
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
Definition ValueTypes.h:373
bool isVector() const
Return true if this is a vector value type.
Definition ValueTypes.h:168
EVT getVectorElementType() const
Given a vector type, return the type of each element.
Definition ValueTypes.h:328
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
Definition ValueTypes.h:336
bool isInteger() const
Return true if this is an integer or a vector integer type.
Definition ValueTypes.h:152