LLVM  16.0.0git
SPIRVRegularizer.cpp
Go to the documentation of this file.
1 //===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements regularization of LLVM IR for SPIR-V. The prototype of
10 // the pass was taken from SPIRV-LLVM translator.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "SPIRV.h"
#include "SPIRVTargetMachine.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include <list>
23 
24 #define DEBUG_TYPE "spirv-regularizer"
25 
26 using namespace llvm;
27 
28 namespace llvm {
30 }
31 
32 namespace {
33 struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
35 
36 public:
37  static char ID;
38  SPIRVRegularizer() : FunctionPass(ID) {
40  }
41  bool runOnFunction(Function &F) override;
42  StringRef getPassName() const override { return "SPIR-V Regularizer"; }
43 
44  void getAnalysisUsage(AnalysisUsage &AU) const override {
46  }
47  void visitCallInst(CallInst &CI);
48 
49 private:
50  void visitCallScalToVec(CallInst *CI, StringRef MangledName,
51  StringRef DemangledName);
52  void runLowerConstExpr(Function &F);
53 };
54 } // namespace
55 
56 char SPIRVRegularizer::ID = 0;
57 
58 INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
59  false)
60 
61 // Since SPIR-V cannot represent constant expression, constant expressions
62 // in LLVM IR need to be lowered to instructions. For each function,
63 // the constant expressions used by instructions of the function are replaced
64 // by instructions placed in the entry block since it dominates all other BBs.
65 // Each constant expression only needs to be lowered once in each function
66 // and all uses of it by instructions in that function are replaced by
67 // one instruction.
68 // TODO: remove redundant instructions for common subexpression.
69 void SPIRVRegularizer::runLowerConstExpr(Function &F) {
70  LLVMContext &Ctx = F.getContext();
71  std::list<Instruction *> WorkList;
72  for (auto &II : instructions(F))
73  WorkList.push_back(&II);
74 
75  auto FBegin = F.begin();
76  while (!WorkList.empty()) {
77  Instruction *II = WorkList.front();
78 
79  auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
80  if (isa<Function>(V))
81  return V;
82  auto *CE = cast<ConstantExpr>(V);
83  LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
84  auto ReplInst = CE->getAsInstruction();
85  auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
86  ReplInst->insertBefore(InsPoint);
87  LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
88  std::vector<Instruction *> Users;
89  // Do not replace use during iteration of use. Do it in another loop.
90  for (auto U : CE->users()) {
91  LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
92  auto InstUser = dyn_cast<Instruction>(U);
93  // Only replace users in scope of current function.
94  if (InstUser && InstUser->getParent()->getParent() == &F)
95  Users.push_back(InstUser);
96  }
97  for (auto &User : Users) {
98  if (ReplInst->getParent() == User->getParent() &&
99  User->comesBefore(ReplInst))
100  ReplInst->moveBefore(User);
101  User->replaceUsesOfWith(CE, ReplInst);
102  }
103  return ReplInst;
104  };
105 
106  WorkList.pop_front();
107  auto LowerConstantVec = [&II, &LowerOp, &WorkList,
108  &Ctx](ConstantVector *Vec,
109  unsigned NumOfOp) -> Value * {
110  if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
111  return isa<ConstantExpr>(V) || isa<Function>(V);
112  })) {
113  // Expand a vector of constexprs and construct it back with
114  // series of insertelement instructions.
115  std::list<Value *> OpList;
116  std::transform(Vec->op_begin(), Vec->op_end(),
117  std::back_inserter(OpList),
118  [LowerOp](Value *V) { return LowerOp(V); });
119  Value *Repl = nullptr;
120  unsigned Idx = 0;
121  auto *PhiII = dyn_cast<PHINode>(II);
122  Instruction *InsPoint =
123  PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
124  std::list<Instruction *> ReplList;
125  for (auto V : OpList) {
126  if (auto *Inst = dyn_cast<Instruction>(V))
127  ReplList.push_back(Inst);
129  (Repl ? Repl : PoisonValue::get(Vec->getType())), V,
130  ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", InsPoint);
131  }
132  WorkList.splice(WorkList.begin(), ReplList);
133  return Repl;
134  }
135  return nullptr;
136  };
137  for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
138  auto *Op = II->getOperand(OI);
139  if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
140  Value *ReplInst = LowerConstantVec(Vec, OI);
141  if (ReplInst)
142  II->replaceUsesOfWith(Op, ReplInst);
143  } else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
144  WorkList.push_front(cast<Instruction>(LowerOp(CE)));
145  } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
146  auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
147  if (!ConstMD)
148  continue;
149  Constant *C = ConstMD->getValue();
150  Value *ReplInst = nullptr;
151  if (auto *Vec = dyn_cast<ConstantVector>(C))
152  ReplInst = LowerConstantVec(Vec, OI);
153  if (auto *CE = dyn_cast<ConstantExpr>(C))
154  ReplInst = LowerOp(CE);
155  if (!ReplInst)
156  continue;
157  Metadata *RepMD = ValueAsMetadata::get(ReplInst);
158  Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
159  II->setOperand(OI, RepMDVal);
160  WorkList.push_front(cast<Instruction>(ReplInst));
161  }
162  }
163  }
164 }
165 
166 // It fixes calls to OCL builtins that accept vector arguments and one of them
167 // is actually a scalar splat.
168 void SPIRVRegularizer::visitCallInst(CallInst &CI) {
169  auto F = CI.getCalledFunction();
170  if (!F)
171  return;
172 
173  auto MangledName = F->getName();
174  size_t n;
175  int status;
176  char *NameStr = itaniumDemangle(F->getName().data(), nullptr, &n, &status);
177  StringRef DemangledName(NameStr);
178 
179  // TODO: add support for other builtins.
180  if (DemangledName.startswith("fmin") || DemangledName.startswith("fmax") ||
181  DemangledName.startswith("min") || DemangledName.startswith("max"))
182  visitCallScalToVec(&CI, MangledName, DemangledName);
183  free(NameStr);
184 }
185 
186 void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
187  StringRef DemangledName) {
188  // Check if all arguments have the same type - it's simple case.
189  auto Uniform = true;
190  Type *Arg0Ty = CI->getOperand(0)->getType();
191  auto IsArg0Vector = isa<VectorType>(Arg0Ty);
192  for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
193  Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
194  if (Uniform)
195  return;
196 
197  auto *OldF = CI->getCalledFunction();
198  Function *NewF = nullptr;
199  if (!Old2NewFuncs.count(OldF)) {
201  SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
202  auto *NewFTy =
203  FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
204  NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
205  *OldF->getParent());
206  ValueToValueMapTy VMap;
207  auto NewFArgIt = NewF->arg_begin();
208  for (auto &Arg : OldF->args()) {
209  auto ArgName = Arg.getName();
210  NewFArgIt->setName(ArgName);
211  VMap[&Arg] = &(*NewFArgIt++);
212  }
214  CloneFunctionInto(NewF, OldF, VMap,
216  NewF->setAttributes(Attrs);
217  Old2NewFuncs[OldF] = NewF;
218  } else {
219  NewF = Old2NewFuncs[OldF];
220  }
221  assert(NewF);
222 
223  // This produces an instruction sequence that implements a splat of
224  // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
225  // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
226  // For instance (transcoding/OpMin.ll), this call
227  // call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
228  // is translated to
229  // %8 = OpUndef %v2uint
230  // %14 = OpConstantComposite %v2uint %uint_1 %uint_10
231  // ...
232  // %10 = OpCompositeInsert %v2uint %uint_5 %8 0
233  // %11 = OpVectorShuffle %v2uint %10 %8 0 0
234  // %call = OpExtInst %v2uint %1 s_min %14 %11
235  auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
236  PoisonValue *PVal = PoisonValue::get(Arg0Ty);
237  Instruction *Inst =
238  InsertElementInst::Create(PVal, CI->getOperand(1), ConstInt, "", CI);
239  ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
240  Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
241  Value *NewVec = new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI);
242  CI->setOperand(1, NewVec);
243  CI->replaceUsesOfWith(OldF, NewF);
244  CI->mutateFunctionType(NewF->getFunctionType());
245 }
246 
248  runLowerConstExpr(F);
249  visit(F);
250  for (auto &OldNew : Old2NewFuncs) {
251  Function *OldF = OldNew.first;
252  Function *NewF = OldNew.second;
253  NewF->takeName(OldF);
254  OldF->eraseFromParent();
255  }
256  return true;
257 }
258 
260  return new SPIRVRegularizer();
261 }
Attrs
Function Attrs
Definition: README_ALTIVEC.txt:215
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
llvm::AArch64PACKey::ID
ID
Definition: AArch64BaseInfo.h:818
llvm::ElementCount
Definition: TypeSize.h:404
InstIterator.h
llvm::Function
Definition: Function.h:60
llvm::CallBase::mutateFunctionType
void mutateFunctionType(FunctionType *FTy)
Definition: InstrTypes.h:1257
llvm::Function::eraseFromParent
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it.
Definition: Function.cpp:368
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1199
llvm::FunctionType::get
static FunctionType * get(Type *Result, ArrayRef< Type * > Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:361
llvm::Instruction::insertBefore
void insertBefore(Instruction *InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified instruction.
Definition: Instruction.cpp:87
llvm::InsertElementInst::Create
static InsertElementInst * Create(Value *Vec, Value *NewElt, Value *Idx, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Definition: Instructions.h:1950
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
llvm::AttributeList
Definition: Attributes.h:430
llvm::CloneFunctionChangeType::LocalChangesOnly
@ LocalChangesOnly
llvm::ValueAsMetadata::get
static ValueAsMetadata * get(Value *V)
Definition: Metadata.cpp:393
llvm::Type::getInt32Ty
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:239
LLVM_DEBUG
#define LLVM_DEBUG(X)
Definition: Debug.h:101
F
#define F(x, y, z)
Definition: MD5.cpp:55
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
Arg
amdgpu Simplify well known AMD library false FunctionCallee Value * Arg
Definition: AMDGPULibCalls.cpp:187
llvm::all_of
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1734
llvm::PassRegistry::getPassRegistry
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Definition: PassRegistry.cpp:24
E
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
llvm::User
Definition: User.h:44
C
(vector float) vec_cmpeq(*A, *B) C
Definition: README_ALTIVEC.txt:86
llvm::CallBase::getCalledFunction
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1397
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
llvm::Instruction
Definition: Instruction.h:42
llvm::PassRegistry
PassRegistry - This class manages the registration and initialization of the pass subsystem as appli...
Definition: PassRegistry.h:38
llvm::ConstantVector::getSplat
static Constant * getSplat(ElementCount EC, Constant *Elt)
Return a ConstantVector with the specified constant in each element.
Definition: Constants.cpp:1391
llvm::ConstantInt::get
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:879
llvm::createSPIRVRegularizerPass
FunctionPass * createSPIRVRegularizerPass()
Definition: SPIRVRegularizer.cpp:259
llvm::Metadata
Root of the metadata hierarchy.
Definition: Metadata.h:62
llvm::CallingConv::ID
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
llvm::Function::getAttributes
AttributeList getAttributes() const
Return the attribute list for this Function.
Definition: Function.h:314
SPIRVTargetMachine.h
llvm::instructions
inst_range instructions(Function *F)
Definition: InstIterator.h:133
llvm::Constant
This is an important base class in LLVM.
Definition: Constant.h:41
llvm::LLVMContext
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
llvm::DenseMap
Definition: DenseMap.h:714
I
#define I(x, y, z)
Definition: MD5.cpp:58
Cloning.h
llvm::initializeSPIRVRegularizerPass
void initializeSPIRVRegularizerPass(PassRegistry &)
transform
instcombine should handle this transform
Definition: README.txt:262
llvm::ConstantVector
Constant Vector Declarations.
Definition: Constants.h:494
llvm::Function::Create
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:137
SPIRV.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
llvm::User::setOperand
void setOperand(unsigned i, Value *Val)
Definition: User.h:174
Demangle.h
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
InstVisitor.h
INITIALIZE_PASS
INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false, false) void SPIRVRegularizer
Definition: SPIRVRegularizer.cpp:58
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
llvm::Value::getContext
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:994
llvm::InstVisitor
Base class for instruction visitors.
Definition: InstVisitor.h:78
llvm::ValueMap< const Value *, WeakTrackingVH >
llvm::User::replaceUsesOfWith
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition: User.cpp:21
runOnFunction
static bool runOnFunction(Function &F, bool PostInlining)
Definition: EntryExitInstrumenter.cpp:85
DEBUG_TYPE
#define DEBUG_TYPE
Definition: SPIRVRegularizer.cpp:24
llvm::MetadataAsValue::get
static MetadataAsValue * get(LLVMContext &Context, Metadata *MD)
Definition: Metadata.cpp:103
llvm::AMDGPU::SendMsg::Op
Op
Definition: SIDefines.h:348
llvm::CallBase::arg_size
unsigned arg_size() const
Definition: InstrTypes.h:1340
PassManager.h
llvm::Function::getFunctionType
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:175
llvm::Function::arg_begin
arg_iterator arg_begin()
Definition: Function.h:722
llvm::ShuffleVectorInst
This instruction constructs a fixed permutation of two input vectors.
Definition: Instructions.h:2008
llvm::User::getNumOperands
unsigned getNumOperands() const
Definition: User.h:191
llvm::sys::fs::status
std::error_code status(const Twine &path, file_status &result, bool follow=true)
Get file status as if by POSIX stat().
llvm::Instruction::getParent
const BasicBlock * getParent() const
Definition: Instruction.h:91
Users
iv Induction Variable Users
Definition: IVUsers.cpp:48
llvm::Pass::getAnalysisUsage
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:98
llvm::itaniumDemangle
char * itaniumDemangle(const char *mangled_name, char *buf, size_t *n, int *status)
Definition: ItaniumDemangle.cpp:368
llvm::IntegerType::get
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:311
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:308
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1474
llvm::CloneFunctionInto
void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, CloneFunctionChangeType Changes, SmallVectorImpl< ReturnInst * > &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)
Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values.
Definition: CloneFunction.cpp:86
llvm::Value::takeName
void takeName(Value *V)
Transfer the name from V to this value.
Definition: Value.cpp:381
llvm::User::getOperand
Value * getOperand(unsigned i) const
Definition: User.h:169
n
The same transformation can work with an even modulo with the addition of a and shrink the compare RHS by the same amount Unless the target supports that transformation probably isn t worthwhile The transformation can also easily be made to work with non zero equality for n
Definition: README.txt:685
llvm::Function::setAttributes
void setAttributes(AttributeList Attrs)
Set the attribute list for this Function.
Definition: Function.h:317
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
llvm::PoisonValue
In order to facilitate speculative execution, many instructions do not invoke immediate undefined beh...
Definition: Constants.h:1404
llvm::PoisonValue::get
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1732