22#include "llvm/IR/IntrinsicsSPIRV.h"
24#define DEBUG_TYPE "spirv-lower"
60 unsigned Intrinsic)
const {
61 unsigned AlignIdx = 3;
63 case Intrinsic::spv_load:
66 case Intrinsic::spv_store: {
67 if (
I.getNumOperands() >= AlignIdx + 1) {
68 auto *AlignOp = cast<ConstantInt>(
I.getOperand(AlignIdx));
69 Info.align =
Align(AlignOp->getZExtValue());
72 cast<ConstantInt>(
I.getOperand(AlignIdx - 1))->getZExtValue());
73 Info.memVT = MVT::i64;
93 Register OpReg =
I.getOperand(OpIdx).getReg();
96 TypeInst && TypeInst->
getOpcode() == SPIRV::OpFunctionParameter
100 if (!ResType || !OpType || OpType->
getOpcode() != SPIRV::OpTypePointer)
109 bool IsEqualTypes = IsSameMF ? ElemType == ResType
115 SPIRV::StorageClass::StorageClass SC =
116 static_cast<SPIRV::StorageClass::StorageClass
>(
122 ResTy, MIB, SPIRV::AccessQualifier::ReadWrite,
false);
126 "insert validation bitcast: incompatible result and operand types");
136 MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
138 I.getOperand(OpIdx).setReg(NewReg);
154 if (FunDef->
getOpcode() != SPIRV::OpFunction)
158 FunDef && FunDef->
getOpcode() == SPIRV::OpFunctionParameter &&
163 DefPtrType && DefPtrType->
getOpcode() == SPIRV::OpTypePointer
190 const Function *
F = dyn_cast<Function>(GV);
209 &FunCall->getParent()->getParent()->getRegInfo();
218 if (BaseTypeInst && BaseTypeInst->
getOpcode() == SPIRV::OpTypePointer) {
230 if (ProcessedMF.find(&MF) != ProcessedMF.end())
241 switch (
MI.getOpcode()) {
242 case SPIRV::OpAtomicLoad:
243 case SPIRV::OpAtomicExchange:
244 case SPIRV::OpAtomicCompareExchange:
245 case SPIRV::OpAtomicCompareExchangeWeak:
246 case SPIRV::OpAtomicIIncrement:
247 case SPIRV::OpAtomicIDecrement:
248 case SPIRV::OpAtomicIAdd:
249 case SPIRV::OpAtomicISub:
250 case SPIRV::OpAtomicSMin:
251 case SPIRV::OpAtomicUMin:
252 case SPIRV::OpAtomicSMax:
253 case SPIRV::OpAtomicUMax:
254 case SPIRV::OpAtomicAnd:
255 case SPIRV::OpAtomicOr:
256 case SPIRV::OpAtomicXor:
265 case SPIRV::OpAtomicStore:
276 case SPIRV::OpPtrCastToGeneric:
279 case SPIRV::OpInBoundsPtrAccessChain:
280 if (
MI.getNumOperands() == 4)
284 case SPIRV::OpFunctionCall:
287 if (
MI.getNumOperands() > 3)
291 case SPIRV::OpFunction:
299 case SPIRV::OpBitwiseOrS:
300 case SPIRV::OpBitwiseOrV:
305 case SPIRV::OpBitwiseAndS:
306 case SPIRV::OpBitwiseAndV:
311 case SPIRV::OpBitwiseXorS:
312 case SPIRV::OpBitwiseXorV:
320 ProcessedMF.insert(&MF);
unsigned const MachineRegisterInfo * MRI
MachineBasicBlock MachineBasicBlock::iterator MBBI
Analysis containing CSE Info
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
void validateFunCallMachineDef(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall, MachineInstr *FunDef)
void validateForwardCalls(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunDef)
const Function * validateFunCall(const SPIRVSubtarget &STI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall)
static void validatePtrTypes(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx, SPIRVType *ResType, const Type *ResTy=nullptr)
This class represents a function call, abstracting a target machine's calling convention.
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
This is an important class for using LLVM in a threaded context.
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Helper class to build MachineInstr.
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
bool constrainAllUses(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const RegisterBankInfo &RBI) const
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Returns the total number of operands.
const MachineOperand & getOperand(unsigned i) const
Flags
Flags values. These may be or'd together.
const GlobalValue * getGlobal() const
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
MachineInstr * getVRegDef(Register Reg) const
getVRegDef - Return the machine instr that defines the specified virtual register or null if none is ...
Wrapper class representing virtual and physical registers.
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
void addForwardCall(const Function *F, MachineInstr *MI)
const Type * getTypeForSPIRVType(const SPIRVType *Ty) const
bool isBitcastCompatible(const SPIRVType *Type1, const SPIRVType *Type2) const
const MachineInstr * getFunctionDefinition(const Function *F)
Register getSPIRVTypeID(const SPIRVType *SpirvType) const
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ=SPIRV::AccessQualifier::ReadWrite, bool EmitIR=true)
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, MachineFunction &MF)
SmallPtrSet< MachineInstr *, 8 > * getForwardCalls(const Function *F)
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const
MachineFunction * setCurrentFunc(MachineFunction &MF)
SPIRVType * getOrCreateSPIRVPointerType(SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SClass=SPIRV::StorageClass::Function)
const Function * getFunctionByDefinition(const MachineInstr *MI)
const SPIRVInstrInfo * getInstrInfo() const override
SPIRVGlobalRegistry * getSPIRVGlobalRegistry() const
const SPIRVRegisterInfo * getRegisterInfo() const override
const RegisterBankInfo * getRegBankInfo() const override
unsigned getNumRegistersForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain targets require unusual breakdowns of certain types.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain combinations of ABIs, Targets and features require that types are legal for some operations a...
bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
void finalizeLowering(MachineFunction &MF) const override
Execute target specific actions to finalize target lowering.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
virtual void finalizeLowering(MachineFunction &MF) const
Execute target specific actions to finalize target lowering.
virtual unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const
Return the number of registers that this ValueType will eventually require.
MVT getRegisterType(MVT VT) const
Return the type of registers that this ValueType will eventually require.
The instances of the Type class are immutable: once they are created, they are never changed.
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
This is an optimization pass for GlobalISel generic memory operations.
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
This struct is a compact representation of a valid (non-zero power of two) alignment.
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
bool isVector() const
Return true if this is a vector value type.
EVT getVectorElementType() const
Given a vector type, return the type of each element.
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
bool isInteger() const
Return true if this is an integer or a vector integer type.