LLVM  10.0.0svn
WebAssemblyFixFunctionBitcasts.cpp
Go to the documentation of this file.
1 //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Fix bitcasted functions.
11 ///
12 /// WebAssembly requires caller and callee signatures to match, however in LLVM,
13 /// some amount of slop is vaguely permitted. Detect mismatch by looking for
14 /// bitcasts of functions and rewrite them to use wrapper functions instead.
15 ///
16 /// This doesn't catch all cases, such as when a function's address is taken in
17 /// one place and cast in another, but it works for many common cases.
18 ///
19 /// Note that LLVM already optimizes away function bitcasts in common cases by
20 /// dropping arguments as needed, so this pass only ends up getting used in less
21 /// common cases.
22 ///
23 //===----------------------------------------------------------------------===//
24 
25 #include "WebAssembly.h"
26 #include "llvm/IR/CallSite.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/Operator.h"
31 #include "llvm/Pass.h"
32 #include "llvm/Support/Debug.h"
34 using namespace llvm;
35 
36 #define DEBUG_TYPE "wasm-fix-function-bitcasts"
37 
38 namespace {
39 class FixFunctionBitcasts final : public ModulePass {
40  StringRef getPassName() const override {
41  return "WebAssembly Fix Function Bitcasts";
42  }
43 
44  void getAnalysisUsage(AnalysisUsage &AU) const override {
45  AU.setPreservesCFG();
47  }
48 
49  bool runOnModule(Module &M) override;
50 
51 public:
52  static char ID;
53  FixFunctionBitcasts() : ModulePass(ID) {}
54 };
55 } // End anonymous namespace
56 
// Register the pass under the key DEBUG_TYPE ("wasm-fix-function-bitcasts")
// with the description string below.
// NOTE(review): the lines after the registration are the body of the pass
// factory. Its header (listing line 61) is missing from this view; per the
// cross-reference index it is presumably
// `ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {` — confirm.
58 INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
59  "Fix mismatching bitcasts for WebAssembly", false, false)
60 
62  return new FixFunctionBitcasts();
63 }
64 
65 // Recursively descend the def-use lists from V to find non-bitcast users of
66 // bitcasts of V.
67 static void findUses(Value *V, Function &F,
68  SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
69  SmallPtrSetImpl<Constant *> &ConstantBCs) {
70  for (Use &U : V->uses()) {
71  if (auto *BC = dyn_cast<BitCastOperator>(U.getUser()))
72  findUses(BC, F, Uses, ConstantBCs);
73  else if (auto *A = dyn_cast<GlobalAlias>(U.getUser()))
74  findUses(A, F, Uses, ConstantBCs);
75  else if (U.get()->getType() != F.getType()) {
76  CallSite CS(U.getUser());
77  if (!CS)
78  // Skip uses that aren't immediately called
79  continue;
80  Value *Callee = CS.getCalledValue();
81  if (Callee != V)
82  // Skip calls where the function isn't the callee
83  continue;
84  if (isa<Constant>(U.get())) {
85  // Only add constant bitcasts to the list once; they get RAUW'd
86  auto C = ConstantBCs.insert(cast<Constant>(U.get()));
87  if (!C.second)
88  continue;
89  }
90  Uses.push_back(std::make_pair(&U, &F));
91  }
92  }
93 }
94 
95 // Create a wrapper function with type Ty that calls F (which may have a
96 // different type). Attempt to support common bitcasted function idioms:
97 // - Call with more arguments than needed: arguments are dropped
98 // - Call with fewer arguments than needed: arguments are filled in with undef
99 // - Return value is not needed: drop it
100 // - Return value needed but not present: supply an undef
101 //
102 // If all the argument types are trivially castable to one another (e.g.
103 // I32 vs pointer type) then we don't create a wrapper at all (return nullptr
104 // instead).
105 //
106 // If there is a type mismatch that we know would result in an invalid wasm
107 // module then generate a wrapper that contains unreachable (i.e. abort at
108 // runtime). Such programs are deep into undefined behaviour territory,
109 // but we choose to fail at runtime rather than generate an invalid module
110 // or fail at compile time. The reason we delay the error is that we want
111 // to support CMake, which expects to be able to compile and link programs
112 // that refer to functions with entirely incorrect signatures (this is how
113 // CMake detects the existence of a function in a toolchain).
114 //
115 // For bitcasts that involve struct types we don't know at this stage if they
116 // would be equivalent at the wasm level and so we can't know if we need to
117 // generate a wrapper.
// Body of createWrapper(Function *F, FunctionType *Ty). It builds a wrapper
// with signature Ty that forwards to F and returns one of:
//   - nullptr when no wrapper turns out to be needed (WrapperNeeded ends up
//     false without a type mismatch);
//   - a "<F>_bitcast" wrapper that adapts the arguments and return value;
//   - a "<F>_bitcast_invalid" wrapper containing only `unreachable` when the
//     types cannot be bridged.
// NOTE(review): listing lines 118 (the function header), 121, 127, 130-131
// and 211 are missing from this view. The declarations of Wrapper, Args, PI
// and PE presumably live there — confirm against the full source.
119  Module *M = F->getParent();
120 
122  F->getName() + "_bitcast", M);
123  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
124  const DataLayout &DL = BB->getModule()->getDataLayout();
125 
126  // Determine what arguments to pass.
128  Function::arg_iterator AI = Wrapper->arg_begin();
129  Function::arg_iterator AE = Wrapper->arg_end();
132  bool TypeMismatch = false;
133  bool WrapperNeeded = false;
134 
// A difference in arity, variadic-ness or return type always requires a
// wrapper. Per-parameter type differences are examined in the loop below.
135  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
136  Type *RtnType = Ty->getReturnType();
137 
138  if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
139  (F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
140  (ExpectedRtnType != RtnType))
141  WrapperNeeded = true;
142 
// Walk the wrapper's arguments (AI) and F's parameters (PI) pairwise, up to
// the shorter of the two lists.
143  for (; AI != AE && PI != PE; ++AI, ++PI) {
144  Type *ArgType = AI->getType();
145  Type *ParamType = *PI;
146 
147  if (ArgType == ParamType) {
148  Args.push_back(&*AI);
149  } else {
// A cast that is a valid no-op (bitcast / inttoptr / ptrtoint) bridges the
// types inside the wrapper; it does not by itself set WrapperNeeded.
150  if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
151  Instruction *PtrCast =
152  CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");
153  BB->getInstList().push_back(PtrCast);
154  Args.push_back(PtrCast);
155  } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
156  LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "
157  << F->getName() << "\n");
// Struct types may or may not match at the wasm level, so give up on wrapping.
// NOTE(review): the loop keeps going and this argument is not added to Args.
// A later unbridgeable argument can still set TypeMismatch.
158  WrapperNeeded = false;
159  } else {
160  LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "
161  << F->getName() << "\n");
162  LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
163  << *ParamType << " Got: " << *ArgType << "\n");
164  TypeMismatch = true;
165  break;
166  }
167  }
168  }
169 
170  if (WrapperNeeded && !TypeMismatch) {
// F parameters with no matching wrapper argument receive undef. Surplus
// wrapper arguments are forwarded only when F is variadic; otherwise they
// are dropped.
171  for (; PI != PE; ++PI)
172  Args.push_back(UndefValue::get(*PI));
173  if (F->isVarArg())
174  for (; AI != AE; ++AI)
175  Args.push_back(&*AI);
176 
177  CallInst *Call = CallInst::Create(F, Args, "", BB);
178 
// NOTE(review): these shadow the ExpectedRtnType/RtnType declared earlier
// with identical initializers, so they are redundant.
179  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
180  Type *RtnType = Ty->getReturnType();
181  // Determine what value to return.
182  if (RtnType->isVoidTy()) {
183  ReturnInst::Create(M->getContext(), BB);
184  } else if (ExpectedRtnType->isVoidTy()) {
185  LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
186  ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);
187  } else if (RtnType == ExpectedRtnType) {
188  ReturnInst::Create(M->getContext(), Call, BB);
189  } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
190  DL)) {
191  Instruction *Cast =
192  CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
193  BB->getInstList().push_back(Cast);
194  ReturnInst::Create(M->getContext(), Cast, BB);
195  } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
196  LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "
197  << F->getName() << "\n");
// No terminator is emitted here. The partially built wrapper is erased below
// because WrapperNeeded becomes false.
198  WrapperNeeded = false;
199  } else {
200  LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "
201  << F->getName() << "\n");
202  LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
203  << " Got: " << *RtnType << "\n");
204  TypeMismatch = true;
205  }
206  }
207 
208  if (TypeMismatch) {
209  // Create a new wrapper that simply contains `unreachable`.
210  Wrapper->eraseFromParent();
212  F->getName() + "_bitcast_invalid", M);
213  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
214  new UnreachableInst(M->getContext(), BB);
// NOTE(review): this appears redundant; the replacement wrapper already
// receives this name in its creation call just above.
215  Wrapper->setName(F->getName() + "_bitcast_invalid");
216  } else if (!WrapperNeeded) {
217  LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()
218  << "\n");
219  Wrapper->eraseFromParent();
220  return nullptr;
221  }
222  LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");
223  return Wrapper;
224 }
225 
226 // Test whether a main function with type FuncTy should be rewritten to have
227 // type MainTy.
228 static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {
229  // Only fix the main function if it's the standard zero-arg form. That way,
230  // the standard cases will work as expected, and users will see signature
231  // mismatches from the linker for non-standard cases.
232  return FuncTy->getReturnType() == MainTy->getReturnType() &&
233  FuncTy->getNumParams() == 0 &&
234  !FuncTy->isVarArg();
235 }
236 
// Pass entry point. It:
//   - collects every call made through a type-mismatched bitcast or alias of
//     a function, plus an artificial call for a zero-argument `main`;
//   - creates or reuses one wrapper per (function, call type) pair;
//   - redirects those calls to the wrapper.
// NOTE(review): listing lines 242, 257 and 274 are missing from this view.
// They presumably declare `Uses`, complete the `MainArgTys` initializer, and
// declare the `Wrappers` map — confirm against the full source.
237 bool FixFunctionBitcasts::runOnModule(Module &M) {
238  LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");
239 
240  Function *Main = nullptr;
241  CallInst *CallMain = nullptr;
243  SmallPtrSet<Constant *, 2> ConstantBCs;
244 
245  // Collect all the places that need wrappers.
246  for (Function &F : M) {
247  findUses(&F, F, Uses, ConstantBCs);
248 
249  // If we have a "main" function, and its type isn't
250  // "int main(int argc, char *argv[])", create an artificial call with it
251  // bitcasted to that type so that we generate a wrapper for it, so that
252  // the C runtime can call it.
// (Only the zero-argument form is handled; see shouldFixMainFunction.)
253  if (F.getName() == "main") {
254  Main = &F;
255  LLVMContext &C = M.getContext();
256  Type *MainArgTys[] = {Type::getInt32Ty(C),
258  FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
259  /*isVarArg=*/false);
260  if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {
261  LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
262  << *F.getFunctionType() << "\n");
263  Value *Args[] = {UndefValue::get(MainArgTys[0]),
264  UndefValue::get(MainArgTys[1])};
265  Value *Casted =
266  ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
// CallMain is created without an insertion point, so it lives in no basic
// block. It exists only so the rewrite loop below builds a wrapper for main,
// and it is deleted directly once the wrapper has been located.
267  CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");
// With two arguments, operand 2 is presumably the callee operand; it is read
// back later through getCalledValue().
268  Use *UseMain = &CallMain->getOperandUse(2);
269  Uses.push_back(std::make_pair(UseMain, &F));
270  }
271  }
272  }
273 
275 
276  for (auto &UseFunc : Uses) {
277  Use *U = UseFunc.first;
278  Function *F = UseFunc.second;
279  auto *PTy = cast<PointerType>(U->get()->getType());
280  auto *Ty = dyn_cast<FunctionType>(PTy->getElementType());
281 
282  // If the function is casted to something like i8* as a "generic pointer"
283  // to be later casted to something else, we can't generate a wrapper for it.
284  // Just ignore such casts for now.
285  if (!Ty)
286  continue;
287 
// Wrappers caches one result per (function, call type) pair. A nullptr entry
// records that createWrapper decided no wrapper is needed.
288  auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
289  if (Pair.second)
290  Pair.first->second = createWrapper(F, Ty);
291 
292  Function *Wrapper = Pair.first->second;
293  if (!Wrapper)
294  continue;
295 
// A constant bitcast is shared by all of its users, so it is replaced
// everywhere at once; findUses records each such constant only once.
// NOTE(review): RAUW also redirects non-call uses of that same constant
// (e.g. stores of the pointer), not just the call that was found.
// Instruction-level bitcasts only have the recorded callee use rewritten.
296  if (isa<Constant>(U->get()))
297  U->get()->replaceAllUsesWith(Wrapper);
298  else
299  U->set(Wrapper);
300  }
301 
302  // If we created a wrapper for main, rename the wrapper so that it's the
303  // one that gets called from startup.
304  if (CallMain) {
305  Main->setName("__original_main");
306  auto *MainWrapper =
307  cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
308  delete CallMain;
309  if (Main->isDeclaration()) {
310  // The wrapper is not needed in this case as we don't need to export
311  // it to anyone else.
312  MainWrapper->eraseFromParent();
313  } else {
314  // Otherwise give the wrapper the same linkage as the original main
315  // function, so that it can be called from the same places.
316  MainWrapper->setName("main");
317  MainWrapper->setLinkage(Main->getLinkage());
318  MainWrapper->setVisibility(Main->getVisibility());
319  }
320  }
321 
// NOTE(review): reports the module as modified even when nothing was rewritten.
322  return true;
323 }
bool isVarArg() const
isVarArg - Return true if this function takes a variable number of arguments.
Definition: Function.h:176
uint64_t CallInst * C
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:67
A parsed version of the target data layout string in and methods for querying it. ...
Definition: DataLayout.h:111
iterator_range< use_iterator > uses()
Definition: Value.h:374
This class represents an incoming formal argument to a Function.
Definition: Argument.h:29
This class represents lattice values for constants.
Definition: AllocatorList.h:23
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:65
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
This class represents a function call, abstracting a target machine's calling convention.
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space...
Definition: Type.cpp:632
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:56
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:733
static CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", Instruction *InsertBefore=nullptr)
Create a BitCast, a PtrToInt, or an IntToPTr cast instruction.
const Use & getOperandUse(unsigned i) const
Definition: User.h:182
arg_iterator arg_end()
Definition: Function.h:704
F(f)
param_iterator param_end() const
Definition: DerivedTypes.h:129
This file contains the entry points for global functions defined in the LLVM WebAssembly back-end...
static bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)
Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op...
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:343
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:195
amdgpu aa AMDGPU Address space based Alias Analysis Wrapper
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:244
A Use represents the edge between a Value definition and its users.
Definition: Use.h:55
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: APFloat.h:41
static void findUses(Value *V, Function &F, SmallVectorImpl< std::pair< Use *, Function *>> &Uses, SmallPtrSetImpl< Constant *> &ConstantBCs)
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:285
Class to represent function types.
Definition: DerivedTypes.h:103
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:96
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:245
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE, "Fix mismatching bitcasts for WebAssembly", false, false) ModulePass *llvm
bool isVarArg() const
Definition: DerivedTypes.h:123
LinkageTypes getLinkage() const
Definition: GlobalValue.h:460
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:429
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:1804
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:140
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:135
VisibilityTypes getVisibility() const
Definition: GlobalValue.h:236
Value * getCalledValue() const
Definition: InstrTypes.h:1280
LLVM Basic Block Representation.
Definition: BasicBlock.h:57
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:45
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:64
This function has undefined behavior.
This file contains the declarations for the subclasses of Constant, which represent the different fla...
unsigned getNumParams() const
Return the number of fixed parameters this function type requires.
Definition: DerivedTypes.h:139
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:370
param_iterator param_begin() const
Definition: DerivedTypes.h:128
Represent the analysis usage information of a pass.
static Function * createWrapper(Function *F, FunctionType *Ty)
static FunctionType * get(Type *Result, ArrayRef< Type *> Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:296
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:99
arg_iterator arg_begin()
Definition: Function.h:695
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1446
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition: Value.cpp:525
size_t size() const
Definition: SmallVector.h:52
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
Definition: Type.cpp:219
ModulePass * createWebAssemblyFixFunctionBitcasts()
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:417
This is a &#39;vector&#39; (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:837
Module.h This file contains the declarations for the Module class.
Type::subtype_iterator param_iterator
Definition: DerivedTypes.h:126
Type * getReturnType() const
Definition: DerivedTypes.h:124
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:301
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:132
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:163
amdgpu Simplify well known AMD library false FunctionCallee Callee
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:175
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:214
static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy)
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:224
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:332
void eraseFromParent()
eraseFromParent - This method unlinks &#39;this&#39; from the containing module and deletes it...
Definition: Function.cpp:226
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition: Globals.cpp:231
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:575
LLVM Value Representation.
Definition: Value.h:73
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:48
#define LLVM_DEBUG(X)
Definition: Debug.h:122
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:277
bool isStructTy() const
True if this is an instance of StructType.
Definition: Type.h:217