38#define DEBUG_TYPE "arm64eccalllowering"
40STATISTIC(Arm64ECCallsLowered,
"Number of Arm64EC calls lowered");
49class AArch64Arm64ECCallLowering :
public ModulePass {
64 int cfguard_module_flag = 0;
86 void canonicalizeThunkType(
Type *
T,
Align Alignment,
bool Ret,
93void AArch64Arm64ECCallLowering::getThunkType(
96 Out << (
TT == Arm64ECThunkType::Entry ?
"$ientry_thunk$cdecl$"
97 :
"$iexit_thunk$cdecl$");
108 if (TT == Arm64ECThunkType::Exit)
112 bool HasSretPtr =
false;
113 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
114 X64ArgTypes, HasSretPtr);
116 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
119 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes,
false);
121 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes,
false);
124void AArch64Arm64ECCallLowering::getThunkArgTypes(
130 if (FT->isVarArg()) {
153 for (
int i = HasSretPtr ? 1 : 0; i < 4; i++) {
163 if (TT != Arm64ECThunkType::Entry) {
175 if (
I == FT->getNumParams()) {
180 for (
unsigned E = FT->getNumParams();
I != E; ++
I) {
184 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(
I);
190 Type *Arm64Ty, *X64Ty;
191 canonicalizeThunkType(FT->getParamType(
I), ParamAlign,
192 false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
198void AArch64Arm64ECCallLowering::getThunkRetType(
202 Type *
T = FT->getReturnType();
206 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
208 int64_t ArgSizeBytes = 0;
211 if (FT->getNumParams()) {
212 auto SRetAttr = AttrList.
getParamAttr(0, Attribute::StructRet);
213 auto InRegAttr = AttrList.
getParamAttr(0, Attribute::InReg);
214 if (SRetAttr.isValid() && InRegAttr.isValid()) {
225 if (SRetAttr.isValid()) {
232 Type *SRetType = SRetAttr.getValueAsType();
234 Type *Arm64Ty, *X64Ty;
235 canonicalizeThunkType(SRetType, SRetAlign,
true, ArgSizeBytes,
236 Out, Arm64Ty, X64Ty);
239 Arm64ArgTypes.
push_back(FT->getParamType(0));
240 X64ArgTypes.
push_back(FT->getParamType(0));
252 canonicalizeThunkType(
T,
Align(),
true, ArgSizeBytes, Out, Arm64RetTy,
263void AArch64Arm64ECCallLowering::canonicalizeThunkType(
266 if (
T->isFloatTy()) {
273 if (
T->isDoubleTy()) {
280 if (
T->isFloatingPointTy()) {
282 "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
285 auto &
DL =
M->getDataLayout();
287 if (
auto *StructTy = dyn_cast<StructType>(
T))
288 if (StructTy->getNumElements() == 1)
289 T = StructTy->getElementType(0);
291 if (
T->isArrayTy()) {
292 Type *ElementTy =
T->getArrayElementType();
293 uint64_t ElementCnt =
T->getArrayNumElements();
294 uint64_t ElementSizePerBytes =
DL.getTypeSizeInBits(ElementTy) / 8;
295 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
297 Out << (ElementTy->
isFloatTy() ?
"F" :
"D") << TotalSizeBytes;
298 if (Alignment.
value() >= 16 && !
Ret)
299 Out <<
"a" << Alignment.
value();
301 if (TotalSizeBytes <= 8) {
310 }
else if (
T->isFloatingPointTy()) {
316 if ((
T->isIntegerTy() ||
T->isPointerTy()) &&
DL.getTypeSizeInBits(
T) <= 64) {
329 if (Alignment.
value() >= 16 && !Ret)
330 Out <<
"a" << Alignment.
value();
349 getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
351 if (
Function *
F =
M->getFunction(ExitThunkName))
357 F->setSection(
".wowthk$aa");
358 F->setComdat(
M->getOrInsertComdat(ExitThunkName));
360 F->addFnAttr(
"frame-pointer",
"all");
364 if (FT->getNumParams()) {
365 auto SRet =
Attrs.getParamAttr(0, Attribute::StructRet);
366 auto InReg =
Attrs.getParamAttr(0, Attribute::InReg);
367 if (SRet.isValid() && !InReg.isValid())
368 F->addParamAttr(1, SRet);
375 M->getOrInsertGlobal(
"__os_arm64x_dispatch_call_no_redirect", PtrTy);
377 auto &
DL =
M->getDataLayout();
381 Args.push_back(
F->arg_begin());
384 if (
RetTy != X64Ty->getReturnType()) {
388 if (
DL.getTypeStoreSize(
RetTy) > 8) {
393 for (
auto &Arg :
make_range(
F->arg_begin() + 1,
F->arg_end())) {
408 if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
409 DL.getTypeStoreSize(Arg.getType()) > 8) {
410 Value *Mem = IRB.CreateAlloca(Arg.getType());
411 IRB.CreateStore(&Arg, Mem);
412 if (
DL.getTypeStoreSize(Arg.getType()) <= 8) {
414 Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
418 Args.push_back(&Arg);
423 Callee = IRB.CreateBitCast(Callee, PtrTy);
428 if (
RetTy != X64Ty->getReturnType()) {
431 if (
DL.getTypeStoreSize(
RetTy) > 8) {
432 RetVal = IRB.CreateLoad(
RetTy, Args[1]);
435 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
436 RetVal = IRB.CreateLoad(
RetTy, CastAlloca);
440 if (
RetTy->isVoidTy())
443 IRB.CreateRet(RetVal);
453 getThunkType(
F->getFunctionType(),
F->getAttributes(),
454 Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty);
455 if (
Function *
F =
M->getFunction(EntryThunkName))
461 Thunk->setSection(
".wowthk$aa");
462 Thunk->setComdat(
M->getOrInsertComdat(EntryThunkName));
464 Thunk->addFnAttr(
"frame-pointer",
"all");
466 auto &
DL =
M->getDataLayout();
471 Type *X64RetType = X64Ty->getReturnType();
473 bool TransformDirectToSRet = X64RetType->
isVoidTy() && !
RetTy->isVoidTy();
474 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
475 unsigned PassthroughArgSize =
F->isVarArg() ? 5 :
Thunk->arg_size();
479 for (
unsigned i = ThunkArgOffset, e = PassthroughArgSize; i !=
e; ++i) {
481 Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
483 DL.getTypeStoreSize(ArgTy) > 8) {
485 if (
DL.getTypeStoreSize(ArgTy) <= 8) {
486 Value *CastAlloca = IRB.CreateAlloca(ArgTy);
487 IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
488 Arg = IRB.CreateLoad(ArgTy, CastAlloca);
490 Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
503 Thunk->addParamAttr(5, Attribute::InReg);
505 Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
509 Args.push_back(IRB.getInt64(0));
514 Callee = IRB.CreateBitCast(Callee, PtrTy);
515 Value *
Call = IRB.CreateCall(Arm64Ty, Callee, Args);
518 if (TransformDirectToSRet) {
519 IRB.CreateStore(RetVal, IRB.CreateBitCast(
Thunk->getArg(1), PtrTy));
520 }
else if (X64RetType !=
RetTy) {
521 Value *CastAlloca = IRB.CreateAlloca(X64RetType);
522 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
523 RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
533 IRB.CreateRet(RetVal);
545 getThunkType(
F->getFunctionType(),
F->getAttributes(),
546 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
548 assert(MangledName &&
"Can't guest exit to function that's already native");
549 std::string ThunkName = *MangledName;
550 if (ThunkName[0] ==
'?' && ThunkName.find(
"@") != std::string::npos) {
551 ThunkName.insert(ThunkName.find(
"@"),
"$exit_thunk");
553 ThunkName.append(
"$exit_thunk");
557 GuestExit->setComdat(
M->getOrInsertComdat(ThunkName));
560 "arm64ec_unmangled_name",
564 "arm64ec_ecmangled_name",
567 F->setMetadata(
"arm64ec_hasguestexit",
MDNode::get(
M->getContext(), {}));
573 if (cfguard_module_flag == 2 && !
F->hasFnAttribute(
"guard_nocf"))
574 GuardFn = GuardFnCFGlobal;
576 GuardFn = GuardFnGlobal;
577 LoadInst *GuardCheckLoad =
B.CreateLoad(GuardFnPtrType, GuardFn);
581 Function *
Thunk = buildExitThunk(
F->getFunctionType(),
F->getAttributes());
583 GuardFnType, GuardCheckLoad,
584 {
B.CreateBitCast(
F,
B.getPtrTy()),
B.CreateBitCast(Thunk,
B.getPtrTy())});
589 Value *GuardRetVal =
B.CreateBitCast(GuardCheck, PtrTy);
592 Args.push_back(&Arg);
596 if (
Call->getType()->isVoidTy())
601 auto SRetAttr =
F->getAttributes().getParamAttr(0, Attribute::StructRet);
602 auto InRegAttr =
F->getAttributes().getParamAttr(0, Attribute::InReg);
603 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
605 Call->addParamAttr(0, SRetAttr);
612void AArch64Arm64ECCallLowering::lowerCall(
CallBase *CB) {
614 "Only applicable for Windows targets");
627 if (cfguard_module_flag == 2 && !CB->
hasFnAttr(
"guard_nocf"))
628 GuardFn = GuardFnCFGlobal;
630 GuardFn = GuardFnGlobal;
631 LoadInst *GuardCheckLoad =
B.CreateLoad(GuardFnPtrType, GuardFn);
637 B.CreateCall(GuardFnType, GuardCheckLoad,
638 {
B.CreateBitCast(CalledOperand,
B.getPtrTy()),
639 B.CreateBitCast(Thunk,
B.getPtrTy())},
645 Value *GuardRetVal =
B.CreateBitCast(GuardCheck, CalledOperand->
getType());
649bool AArch64Arm64ECCallLowering::runOnModule(
Module &
Mod) {
657 mdconst::extract_or_null<ConstantInt>(
M->getModuleFlag(
"cfguard")))
658 cfguard_module_flag = MD->getZExtValue();
660 PtrTy = PointerType::getUnqual(
M->getContext());
664 GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy},
false);
665 GuardFnPtrType = PointerType::get(GuardFnType, 0);
667 M->getOrInsertGlobal(
"__os_arm64x_check_icall_cfg", GuardFnPtrType);
669 M->getOrInsertGlobal(
"__os_arm64x_check_icall", GuardFnPtrType);
673 if (!
F.isDeclaration() &&
676 processFunction(
F, DirectCalledFns);
685 if (!
F.isDeclaration() && (!
F.hasLocalLinkage() ||
F.hasAddressTaken()) &&
691 {&
F, buildEntryThunk(&
F), Arm64ECThunkType::Entry});
696 {
F, buildExitThunk(
F->getFunctionType(),
F->getAttributes()),
697 Arm64ECThunkType::Exit});
698 if (!
F->hasDLLImportStorageClass())
700 {buildGuestExitThunk(
F),
F, Arm64ECThunkType::GuestExit});
703 if (!ThunkMapping.
empty()) {
705 for (ThunkInfo &Thunk : ThunkMapping) {
709 ConstantInt::get(
M->getContext(),
APInt(32, uint8_t(
Thunk.Kind)))}));
713 ThunkMappingArrayElems.
size()),
714 ThunkMappingArrayElems);
717 "llvm.arm64ec.symbolmap");
723bool AArch64Arm64ECCallLowering::processFunction(
734 if (!
F.hasLocalLinkage() ||
F.hasAddressTaken()) {
735 if (std::optional<std::string> MangledName =
737 F.setMetadata(
"arm64ec_unmangled_name",
740 if (
F.hasComdat() &&
F.getComdat()->getName() ==
F.getName()) {
741 Comdat *MangledComdat =
M->getOrInsertComdat(MangledName.value());
745 User->setComdat(MangledComdat);
747 F.setName(MangledName.value());
757 auto *CB = dyn_cast<CallBase>(&
I);
769 F->isIntrinsic() || !
F->isDeclaration())
777 ++Arm64ECCallsLowered;
781 if (IndirectCalls.
empty())
790char AArch64Arm64ECCallLowering::ID = 0;
792 "AArch64Arm64ECCallLowering",
false,
false)
795 return new AArch64Arm64ECCallLowering;
static cl::opt< bool > LowerDirectToIndirect("arm64ec-lower-direct-to-indirect", cl::Hidden, cl::init(true))
static cl::opt< bool > GenerateThunks("arm64ec-generate-thunks", cl::Hidden, cl::init(true))
OperandBundleDefT< Value * > OperandBundleDef
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallString class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
static SymbolRef::Type getType(const Symbol *Sym)
Class for arbitrary precision integers.
This class represents an incoming formal argument to a Function.
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Attribute getParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const
Return the attribute object that exists at the arg index.
MaybeAlign getParamAlignment(unsigned ArgNo) const
Return the alignment for the specified function parameter.
LLVM Basic Block Representation.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
bool isInlineAsm() const
Check if this call is an inline asm statement.
void setCallingConv(CallingConv::ID CC)
std::optional< OperandBundleUse > getOperandBundle(StringRef Name) const
Return an operand bundle by name, if present.
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool hasFnAttr(Attribute::AttrKind Kind) const
Determine whether this call has the given attribute.
CallingConv::ID getCallingConv() const
Value * getCalledOperand() const
FunctionType * getFunctionType() const
void setCalledOperand(Value *V)
AttributeList getAttributes() const
Return the parameter attributes for this call.
This class represents a function call, abstracting a target machine's calling convention.
static Constant * get(ArrayType *T, ArrayRef< Constant * > V)
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static Constant * getAnon(ArrayRef< Constant * > V, bool Packed=false)
Return an anonymous struct that has the specified elements.
This is an important base class in LLVM.
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
@ WeakODRLinkage
Same, but only replaced by something equivalent.
@ ExternalLinkage
Externally visible function.
@ LinkOnceODRLinkage
Same, but only replaced by something equivalent.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr if the function does not...
An instruction for reading from memory.
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
static MDString * get(LLVMContext &Context, StringRef Str)
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Comdat * getOrInsertComdat(StringRef Name)
Return the Comdat in the module with the specified name.
A container for an operand bundle being viewed as a set of values rather than a set of uses.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A vector that has set insertion semantics.
bool insert(const value_type &X)
Insert a new element into the SetVector.
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Triple - Helper class for working with autoconf configuration names.
bool isOSWindows() const
Tests whether the OS is Windows.
The instances of the Type class are immutable: once they are created, they are never changed.
bool isArrayTy() const
True if this is an instance of ArrayType.
bool isPointerTy() const
True if this is an instance of PointerType.
bool isFloatTy() const
Return true if this is 'float', a 32-bit IEEE fp type.
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
static Type * getVoidTy(LLVMContext &C)
bool isStructTy() const
True if this is an instance of StructType.
bool isDoubleTy() const
Return true if this is 'double', a 64-bit IEEE fp type.
static IntegerType * getInt64Ty(LLVMContext &C)
bool isVoidTy() const
Return true if this is 'void'.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
A raw_ostream that discards all output.
This class implements an extremely fast bulk output stream that can only output to a stream.
A raw_ostream that writes to an SmallVector or SmallString.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
@ ARM64EC_Thunk_Native
Calling convention used in the ARM64EC ABI to implement calls between ARM64 code and thunks.
@ CFGuard_Check
Special calling convention on Windows for calling the Control Guard Check ICall function.
@ ARM64EC_Thunk_X64
Calling convention used in the ARM64EC ABI to implement calls between x64 code and thunks.
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
std::optional< std::string > getArm64ECMangledFunctionName(StringRef Name)
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &)
ModulePass * createAArch64Arm64ECCallLoweringPass()
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
SmallVector< ValueTypeFromRangeType< R >, Size > to_vector(R &&Range)
Given a range of type R, iterate the entire range and return a SmallVector with elements of the vecto...
This struct is a compact representation of a valid (non-zero power of two) alignment.
uint64_t value() const
This is a hole in the type system and should not be abused.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.