53#include "llvm/IR/IntrinsicsSPIRV.h"
60class SPIRVLegalizePointerCastImpl {
67 {Arg->
getType()}, OfType, Arg, {},
B);
68 GR->addAssignPtrTypeInstr(Arg, AssignCI);
71 static FixedVectorType *makeVectorFromTotalBits(
Type *ElemTy,
74 assert(ElemBits && TotalBits % ElemBits == 0 &&
75 "TotalBits must be divisible by element bit size");
80 FixedVectorType *DstTy) {
83 "shuffle resize expects identical element types");
86 const unsigned NumSource = SrcTy->getNumElements();
88 SmallVector<int>
Mask(NumNeeded);
89 for (
unsigned I = 0;
I < NumNeeded; ++
I)
90 Mask[
I] = (
I < NumSource) ?
static_cast<int>(
I) : -1;
92 Value *Resized =
B.CreateShuffleVector(V, V, Mask);
93 buildAssignType(
B, DstTy, Resized);
102 FixedVectorType *TargetType,
Value *Source) {
103 LoadInst *NewLoad =
B.CreateLoad(SourceType, Source);
104 buildAssignType(
B, SourceType, NewLoad);
105 Value *AssignValue = NewLoad;
107 const DataLayout &
DL =
B.GetInsertBlock()->getModule()->getDataLayout();
108 TypeSize TargetTypeSize =
DL.getTypeSizeInBits(TargetType);
109 TypeSize SourceTypeSize =
DL.getTypeSizeInBits(SourceType);
111 Value *BitcastSrcVal = NewLoad;
112 FixedVectorType *BitcastSrcTy =
114 FixedVectorType *BitcastDstTy = TargetType;
116 if (TargetTypeSize != SourceTypeSize) {
117 unsigned TargetElemBits =
119 if (SourceTypeSize % TargetElemBits == 0) {
122 BitcastDstTy = makeVectorFromTotalBits(TargetType->
getElementType(),
126 BitcastSrcTy = makeVectorFromTotalBits(SourceType->
getElementType(),
128 BitcastSrcVal = resizeVectorBitsWithShuffle(
B, NewLoad, BitcastSrcTy);
132 B.CreateIntrinsic(Intrinsic::spv_bitcast,
133 {BitcastDstTy, BitcastSrcTy}, {BitcastSrcVal});
134 buildAssignType(
B, BitcastDstTy, AssignValue);
135 if (BitcastDstTy == TargetType)
143 Value *Output =
B.CreateShuffleVector(AssignValue, AssignValue, Mask);
144 buildAssignType(
B, TargetType, Output);
152 Value *Source, LoadInst *BadLoad) {
155 SmallVector<Value *, 8>
Args{
B.getInt1(
false),
Source};
157 Type *AggregateType = GR->findDeducedElementType(Source);
158 assert(AggregateType &&
"Could not deduce aggregate type");
159 buildGEPIndexChain(
B, ElementType, AggregateType, Args);
161 auto *
GEP =
B.CreateIntrinsic(Intrinsic::spv_gep, {
Types}, {
Args});
162 GR->buildAssignPtr(
B, ElementType,
GEP);
164 LoadInst *LI =
B.CreateLoad(ElementType,
GEP);
166 buildAssignType(
B, ElementType, LI);
170 buildVectorFromLoadedElements(
IRBuilder<> &
B, FixedVectorType *TargetType,
171 SmallVector<Value *, 4> &LoadedElements) {
174 buildAssignType(
B, TargetType, NewVector);
182 NewVector =
B.CreateIntrinsic(Intrinsic::spv_insertelt, {
Types}, {
Args});
183 buildAssignType(
B, TargetType, NewVector);
192 FixedVectorType *ArrElemVecTy) {
196 SmallVector<Value *, 4> LoadedElements;
199 unsigned ArrayIndex =
I / ScalarsPerArrayElement;
200 unsigned ElementIndexInArrayElem =
I % ScalarsPerArrayElement;
202 std::array<Value *, 4>
Args = {
203 B.getInt1(
false),
Source,
B.getInt32(0),
204 ConstantInt::get(
B.getInt32Ty(), ArrayIndex)};
205 auto *ElementPtr =
B.CreateIntrinsic(Intrinsic::spv_gep, {
Types}, {
Args});
206 GR->buildAssignPtr(
B, ArrElemVecTy, ElementPtr);
207 Value *LoadVec =
B.CreateLoad(ArrElemVecTy, ElementPtr);
208 buildAssignType(
B, ArrElemVecTy, LoadVec);
209 LoadedElements.
push_back(makeExtractElement(
B, TargetElemTy, LoadVec,
210 ElementIndexInArrayElem));
212 return buildVectorFromLoadedElements(
B, TargetType, LoadedElements);
218 SmallVector<Value *, 4> LoadedElements;
222 std::array<Value *, 4>
Args = {
B.getInt1(
false),
Source,
224 ConstantInt::get(
B.getInt32Ty(),
I)};
225 auto *ElementPtr =
B.CreateIntrinsic(Intrinsic::spv_gep, {
Types}, {
Args});
233 return buildVectorFromLoadedElements(
B, TargetType, LoadedElements);
238 Value *DstArrayPtr, ArrayType *ArrTy,
244 unsigned SrcNumElements = SrcVecTy->getNumElements();
246 SrcNumElements % ScalarsPerArrayElement == 0 &&
247 "Source vector size must be a multiple of array element vector size");
252 for (
unsigned I = 0;
I < SrcNumElements;
I += ScalarsPerArrayElement) {
253 unsigned ArrayIndex =
I / ScalarsPerArrayElement;
255 std::array<Value *, 4>
Args = {
256 B.getInt1(
false), DstArrayPtr,
B.getInt32(0),
257 ConstantInt::get(
B.getInt32Ty(), ArrayIndex)};
258 auto *ElementPtr =
B.CreateIntrinsic(Intrinsic::spv_gep, {
Types}, {
Args});
259 GR->buildAssignPtr(
B, ArrElemVecTy, ElementPtr);
263 for (
unsigned J = 0; J < ScalarsPerArrayElement; ++J)
264 Elements.push_back(makeExtractElement(
B, ElemTy, SrcVector,
I + J));
267 Value *Vec = buildVectorFromLoadedElements(
B, ArrElemVecTy, Elements);
268 StoreInst *
SI =
B.CreateStore(Vec, ElementPtr);
269 SI->setAlignment(Alignment);
275 Value *DstArrayPtr, ArrayType *ArrTy,
278 Type *ElemTy = ArrTy->getElementType();
281 assert(VecTy->getElementType() == ElemTy &&
282 "Element types of array and vector must be the same.");
286 for (
unsigned I = 0,
E = VecTy->getNumElements();
I <
E; ++
I) {
288 std::array<Value *, 4>
Args = {
B.getInt1(
false), DstArrayPtr,
290 ConstantInt::get(
B.getInt32Ty(),
I)};
291 auto *ElementPtr =
B.CreateIntrinsic(Intrinsic::spv_gep, {
Types}, {
Args});
292 GR->buildAssignPtr(
B, ElemTy, ElementPtr);
295 Value *Element = makeExtractElement(
B, ElemTy, SrcVector,
I);
296 StoreInst *
SI =
B.CreateStore(Element, ElementPtr);
297 SI->setAlignment(Alignment);
304 Value *OriginalOperand) {
305 Type *FromTy = GR->findDeducedElementType(OriginalOperand);
306 Type *ToTy = GR->findDeducedElementType(CastedOperand);
307 Value *Output =
nullptr;
315 B.SetInsertPoint(LI);
320 if (isTypeFirstElementAggregate(ToTy, FromTy))
321 Output = loadFirstValueFromAggregate(
B, ToTy, OriginalOperand, LI);
326 Output = loadVectorFromVector(
B, SVT, DVT, OriginalOperand);
327 else if (
SAT && DVT &&
SAT->getElementType() == DVT->getElementType())
328 Output = loadVectorFromArray(
B, DVT, OriginalOperand);
329 else if (MAT && DVT && MAT->getElementType() == DVT->getElementType())
330 Output = loadVectorFromMatrixArray(
B, DVT, OriginalOperand, MAT);
334 GR->replaceAllUsesWith(LI, Output,
true);
335 DeadInstructions.push_back(LI);
346 B.CreateIntrinsic(Intrinsic::spv_insertelt, {
Types}, {
Args});
347 buildAssignType(
B,
Vector->getType(), NewI);
359 B.CreateIntrinsic(Intrinsic::spv_extractelt, {
Types}, {
Args});
360 buildAssignType(
B, ElementType, NewI);
369 FixedVectorType *DstType =
378 [[maybe_unused]]
auto dstBitWidth =
380 [[maybe_unused]]
auto srcBitWidth =
382 assert(dstBitWidth == srcBitWidth &&
383 "Unsupported bitcast between vectors of different sizes.");
386 B.CreateIntrinsic(Intrinsic::spv_bitcast, {DstType, SrcType}, {Src});
387 buildAssignType(
B, DstType, Src);
390 StoreInst *
SI =
B.CreateStore(Src, Dst);
391 SI->setAlignment(Alignment);
396 LoadInst *LI =
B.CreateLoad(DstType, Dst);
398 Value *OldValues = LI;
399 buildAssignType(
B, OldValues->
getType(), OldValues);
400 Value *NewValues = Src;
405 OldValues = makeInsertElement(
B, OldValues, Element,
I);
408 StoreInst *
SI =
B.CreateStore(OldValues, Dst);
409 SI->setAlignment(Alignment);
414 SmallVectorImpl<Value *> &Indices) {
417 if (Search == Aggregate)
421 buildGEPIndexChain(
B, Search,
ST->getTypeAtIndex(0u), Indices);
423 buildGEPIndexChain(
B, Search, AT->getElementType(), Indices);
425 buildGEPIndexChain(
B, Search, VT->getElementType(), Indices);
432 Type *DstPointeeType, Align Alignment) {
433 std::array<Type *, 2>
Types = {Dst->getType(), Dst->getType()};
434 SmallVector<Value *, 8>
Args{
B.getInt1(
true), Dst};
435 buildGEPIndexChain(
B, Src->getType(), DstPointeeType, Args);
436 auto *
GEP =
B.CreateIntrinsic(Intrinsic::spv_gep, {
Types}, {
Args});
437 GR->buildAssignPtr(
B, Src->getType(),
GEP);
438 StoreInst *
SI =
B.CreateStore(Src,
GEP);
439 SI->setAlignment(Alignment);
443 bool isTypeFirstElementAggregate(
Type *Search,
Type *Aggregate) {
444 if (Search == Aggregate)
447 return isTypeFirstElementAggregate(Search,
ST->getTypeAtIndex(0u));
449 return isTypeFirstElementAggregate(Search, VT->getElementType());
451 return isTypeFirstElementAggregate(Search, AT->getElementType());
458 Value *Dst, Align Alignment) {
459 Type *ToTy = GR->findDeducedElementType(Dst);
460 Type *FromTy = Src->getType();
468 B.SetInsertPoint(BadStore);
469 if (isTypeFirstElementAggregate(FromTy, ToTy))
470 storeToFirstValueAggregate(
B, Src, Dst, ToTy, Alignment);
471 else if (D_VT && S_VT)
472 storeVectorFromVector(
B, Src, Dst, Alignment);
473 else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
474 storeToFirstValueAggregate(
B, Src, Dst, D_VT, Alignment);
475 else if (D_AT && S_VT && S_VT->getElementType() == D_AT->getElementType())
476 storeArrayFromVector(
B, Src, Dst, D_AT, Alignment);
477 else if (D_MAT && S_VT && D_MAT->getElementType() == S_VT->getElementType())
478 storeMatrixArrayFromVector(
B, Src, Dst, D_AT, Alignment);
482 DeadInstructions.push_back(BadStore);
485 void legalizePointerCast(IntrinsicInst *
II) {
487 Value *OriginalOperand =
II->getOperand(0);
490 std::vector<Value *>
Users;
491 for (Use &U :
II->uses())
492 Users.push_back(
U.getUser());
496 transformLoad(
B, LI, CastedOperand, OriginalOperand);
501 transformStore(
B, SI,
SI->getValueOperand(), OriginalOperand,
507 if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
508 DeadInstructions.push_back(Intrin);
512 if (Intrin->getIntrinsicID() == Intrinsic::spv_gep) {
513 GR->replaceAllUsesWith(CastedOperand, OriginalOperand,
518 if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
521 Alignment =
Align(
C->getZExtValue());
522 transformStore(
B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
531 DeadInstructions.push_back(
II);
535 SPIRVLegalizePointerCastImpl(
const SPIRVTargetMachine &TM) : TM(TM) {}
537 bool run(Function &
F) {
538 const SPIRVSubtarget &
ST = TM.getSubtarget<SPIRVSubtarget>(
F);
539 GR =
ST.getSPIRVGlobalRegistry();
540 DeadInstructions.clear();
542 std::vector<IntrinsicInst *> WorkList;
546 if (
II &&
II->getIntrinsicID() == Intrinsic::spv_ptrcast)
547 WorkList.push_back(
II);
551 for (IntrinsicInst *
II : WorkList)
552 legalizePointerCast(
II);
554 for (Instruction *
I : DeadInstructions)
555 I->eraseFromParent();
557 return DeadInstructions.size() != 0;
561 const SPIRVTargetMachine &TM;
562 SPIRVGlobalRegistry *GR =
nullptr;
563 std::vector<Instruction *> DeadInstructions;
566class SPIRVLegalizePointerCastLegacy :
public FunctionPass {
569 SPIRVLegalizePointerCastLegacy(
const SPIRVTargetMachine &TM)
570 : FunctionPass(ID), TM(TM) {}
573 return SPIRVLegalizePointerCastImpl(TM).run(
F);
577 const SPIRVTargetMachine &TM;
587char SPIRVLegalizePointerCastLegacy::ID = 0;
589 "SPIRV legalize pointer cast pass",
false,
false)
592 return new SPIRVLegalizePointerCastLegacy(*TM);
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool runOnFunction(Function &F, bool PostInlining)
iv Induction Variable Users
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
unsigned getNumElements() const
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
FunctionPass class - This class is used to implement most global optimizations.
void setAlignment(Align Align)
Type * getPointerOperandType() const
Align getAlign() const
Return the alignment of the access that is being performed.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void push_back(const T &Elt)
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Type * getType() const
All values are typed, get the type of this value.
Type * getElementType() const
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
ElementType
The element type of an SRV or UAV resource.
friend class Instruction
Iterator for Instructions in a `BasicBlock.
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
FunctionAddr VTableAddr uintptr_t uintptr_t Int32Ty
CallInst * buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef< Type * > Types, Value *Arg, Value *Arg2, ArrayRef< Constant * > Imms, IRBuilder<> &B)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
FunctionPass * createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM)