LLVM 19.0.0git
AArch64Arm64ECCallLowering.cpp
Go to the documentation of this file.
1//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file contains the IR transform to lower external or indirect calls for
11/// the ARM64EC calling convention. Such calls must go through the runtime, so
12/// we can translate the calling convention for calls into the emulator.
13///
14/// This subsumes Control Flow Guard handling.
15///
16//===----------------------------------------------------------------------===//
17
18#include "AArch64.h"
19#include "llvm/ADT/SetVector.h"
22#include "llvm/ADT/Statistic.h"
23#include "llvm/IR/CallingConv.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/Instruction.h"
26#include "llvm/IR/Mangler.h"
28#include "llvm/Object/COFF.h"
29#include "llvm/Pass.h"
32
33using namespace llvm;
34using namespace llvm::COFF;
35
37
38#define DEBUG_TYPE "arm64eccalllowering"
39
40STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
41
42static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
43 cl::Hidden, cl::init(true));
44static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
45 cl::init(true));
46
47namespace {
48
49class AArch64Arm64ECCallLowering : public ModulePass {
50public:
51 static char ID;
52 AArch64Arm64ECCallLowering() : ModulePass(ID) {
54 }
55
56 Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
57 Function *buildEntryThunk(Function *F);
58 void lowerCall(CallBase *CB);
59 Function *buildGuestExitThunk(Function *F);
60 bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
61 bool runOnModule(Module &M) override;
62
63private:
64 int cfguard_module_flag = 0;
65 FunctionType *GuardFnType = nullptr;
66 PointerType *GuardFnPtrType = nullptr;
67 Constant *GuardFnCFGlobal = nullptr;
68 Constant *GuardFnGlobal = nullptr;
69 Module *M = nullptr;
70
71 Type *PtrTy;
72 Type *I64Ty;
73 Type *VoidTy;
74
75 void getThunkType(FunctionType *FT, AttributeList AttrList,
77 FunctionType *&Arm64Ty, FunctionType *&X64Ty);
78 void getThunkRetType(FunctionType *FT, AttributeList AttrList,
79 raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
80 SmallVectorImpl<Type *> &Arm64ArgTypes,
81 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
82 void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
84 SmallVectorImpl<Type *> &Arm64ArgTypes,
85 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
86 void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
87 uint64_t ArgSizeBytes, raw_ostream &Out,
88 Type *&Arm64Ty, Type *&X64Ty);
89};
90
91} // end anonymous namespace
92
93void AArch64Arm64ECCallLowering::getThunkType(
95 raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty) {
96 Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
97 : "$iexit_thunk$cdecl$");
98
99 Type *Arm64RetTy;
100 Type *X64RetTy;
101
102 SmallVector<Type *> Arm64ArgTypes;
103 SmallVector<Type *> X64ArgTypes;
104
105 // The first argument to a thunk is the called function, stored in x9.
106 // For exit thunks, we pass the called function down to the emulator;
107 // for entry/guest exit thunks, we just call the Arm64 function directly.
108 if (TT == Arm64ECThunkType::Exit)
109 Arm64ArgTypes.push_back(PtrTy);
110 X64ArgTypes.push_back(PtrTy);
111
112 bool HasSretPtr = false;
113 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
114 X64ArgTypes, HasSretPtr);
115
116 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
117 HasSretPtr);
118
119 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
120
121 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
122}
123
124void AArch64Arm64ECCallLowering::getThunkArgTypes(
126 raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
127 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
128
129 Out << "$";
130 if (FT->isVarArg()) {
131 // We treat the variadic function's thunk as a normal function
132 // with the following type on the ARM side:
133 // rettype exitthunk(
134 // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
135 //
136 // that can coverage all types of variadic function.
137 // x9 is similar to normal exit thunk, store the called function.
138 // x0-x3 is the arguments be stored in registers.
139 // x4 is the address of the arguments on the stack.
140 // x5 is the size of the arguments on the stack.
141 //
142 // On the x64 side, it's the same except that x5 isn't set.
143 //
144 // If both the ARM and X64 sides are sret, there are only three
145 // arguments in registers.
146 //
147 // If the X64 side is sret, but the ARM side isn't, we pass an extra value
148 // to/from the X64 side, and let SelectionDAG transform it into a memory
149 // location.
150 Out << "varargs";
151
152 // x0-x3
153 for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
154 Arm64ArgTypes.push_back(I64Ty);
155 X64ArgTypes.push_back(I64Ty);
156 }
157
158 // x4
159 Arm64ArgTypes.push_back(PtrTy);
160 X64ArgTypes.push_back(PtrTy);
161 // x5
162 Arm64ArgTypes.push_back(I64Ty);
163 if (TT != Arm64ECThunkType::Entry) {
164 // FIXME: x5 isn't actually used by the x64 side; revisit once we
165 // have proper isel for varargs
166 X64ArgTypes.push_back(I64Ty);
167 }
168 return;
169 }
170
171 unsigned I = 0;
172 if (HasSretPtr)
173 I++;
174
175 if (I == FT->getNumParams()) {
176 Out << "v";
177 return;
178 }
179
180 for (unsigned E = FT->getNumParams(); I != E; ++I) {
181#if 0
182 // FIXME: Need more information about argument size; see
183 // https://reviews.llvm.org/D132926
184 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
185 Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
186#else
187 uint64_t ArgSizeBytes = 0;
188 Align ParamAlign = Align();
189#endif
190 Type *Arm64Ty, *X64Ty;
191 canonicalizeThunkType(FT->getParamType(I), ParamAlign,
192 /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
193 Arm64ArgTypes.push_back(Arm64Ty);
194 X64ArgTypes.push_back(X64Ty);
195 }
196}
197
198void AArch64Arm64ECCallLowering::getThunkRetType(
199 FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
200 Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
201 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
202 Type *T = FT->getReturnType();
203#if 0
204 // FIXME: Need more information about argument size; see
205 // https://reviews.llvm.org/D132926
206 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
207#else
208 int64_t ArgSizeBytes = 0;
209#endif
210 if (T->isVoidTy()) {
211 if (FT->getNumParams()) {
212 auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet);
213 auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg);
214 if (SRetAttr.isValid() && InRegAttr.isValid()) {
215 // sret+inreg indicates a call that returns a C++ class value. This is
216 // actually equivalent to just passing and returning a void* pointer
217 // as the first argument. Translate it that way, instead of trying
218 // to model "inreg" in the thunk's calling convention, to simplify
219 // the rest of the code.
220 Out << "i8";
221 Arm64RetTy = I64Ty;
222 X64RetTy = I64Ty;
223 return;
224 }
225 if (SRetAttr.isValid()) {
226 // FIXME: Sanity-check the sret type; if it's an integer or pointer,
227 // we'll get screwy mangling/codegen.
228 // FIXME: For large struct types, mangle as an integer argument and
229 // integer return, so we can reuse more thunks, instead of "m" syntax.
230 // (MSVC mangles this case as an integer return with no argument, but
231 // that's a miscompile.)
232 Type *SRetType = SRetAttr.getValueAsType();
233 Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
234 Type *Arm64Ty, *X64Ty;
235 canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
236 Out, Arm64Ty, X64Ty);
237 Arm64RetTy = VoidTy;
238 X64RetTy = VoidTy;
239 Arm64ArgTypes.push_back(FT->getParamType(0));
240 X64ArgTypes.push_back(FT->getParamType(0));
241 HasSretPtr = true;
242 return;
243 }
244 }
245
246 Out << "v";
247 Arm64RetTy = VoidTy;
248 X64RetTy = VoidTy;
249 return;
250 }
251
252 canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64RetTy,
253 X64RetTy);
254 if (X64RetTy->isPointerTy()) {
255 // If the X64 type is canonicalized to a pointer, that means it's
256 // passed/returned indirectly. For a return value, that means it's an
257 // sret pointer.
258 X64ArgTypes.push_back(X64RetTy);
259 X64RetTy = VoidTy;
260 }
261}
262
263void AArch64Arm64ECCallLowering::canonicalizeThunkType(
264 Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
265 Type *&Arm64Ty, Type *&X64Ty) {
266 if (T->isFloatTy()) {
267 Out << "f";
268 Arm64Ty = T;
269 X64Ty = T;
270 return;
271 }
272
273 if (T->isDoubleTy()) {
274 Out << "d";
275 Arm64Ty = T;
276 X64Ty = T;
277 return;
278 }
279
280 if (T->isFloatingPointTy()) {
282 "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
283 }
284
285 auto &DL = M->getDataLayout();
286
287 if (auto *StructTy = dyn_cast<StructType>(T))
288 if (StructTy->getNumElements() == 1)
289 T = StructTy->getElementType(0);
290
291 if (T->isArrayTy()) {
292 Type *ElementTy = T->getArrayElementType();
293 uint64_t ElementCnt = T->getArrayNumElements();
294 uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
295 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
296 if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
297 Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
298 if (Alignment.value() >= 16 && !Ret)
299 Out << "a" << Alignment.value();
300 Arm64Ty = T;
301 if (TotalSizeBytes <= 8) {
302 // Arm64 returns small structs of float/double in float registers;
303 // X64 uses RAX.
304 X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8);
305 } else {
306 // Struct is passed directly on Arm64, but indirectly on X64.
307 X64Ty = PtrTy;
308 }
309 return;
310 } else if (T->isFloatingPointTy()) {
311 report_fatal_error("Only 32 and 64 bit floating points are supported for "
312 "ARM64EC thunks");
313 }
314 }
315
316 if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
317 Out << "i8";
318 Arm64Ty = I64Ty;
319 X64Ty = I64Ty;
320 return;
321 }
322
323 unsigned TypeSize = ArgSizeBytes;
324 if (TypeSize == 0)
325 TypeSize = DL.getTypeSizeInBits(T) / 8;
326 Out << "m";
327 if (TypeSize != 4)
328 Out << TypeSize;
329 if (Alignment.value() >= 16 && !Ret)
330 Out << "a" << Alignment.value();
331 // FIXME: Try to canonicalize Arm64Ty more thoroughly?
332 Arm64Ty = T;
333 if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
334 // Pass directly in an integer register
335 X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8);
336 } else {
337 // Passed directly on Arm64, but indirectly on X64.
338 X64Ty = PtrTy;
339 }
340}
341
342// This function builds the "exit thunk", a function which translates
343// arguments and return values when calling x64 code from AArch64 code.
344Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
345 AttributeList Attrs) {
346 SmallString<256> ExitThunkName;
347 llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
348 FunctionType *Arm64Ty, *X64Ty;
349 getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
350 X64Ty);
351 if (Function *F = M->getFunction(ExitThunkName))
352 return F;
353
355 ExitThunkName, M);
356 F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
357 F->setSection(".wowthk$aa");
358 F->setComdat(M->getOrInsertComdat(ExitThunkName));
359 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
360 F->addFnAttr("frame-pointer", "all");
361 // Only copy sret from the first argument. For C++ instance methods, clang can
362 // stick an sret marking on a later argument, but it doesn't actually affect
363 // the ABI, so we can omit it. This avoids triggering a verifier assertion.
364 if (FT->getNumParams()) {
365 auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
366 auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
367 if (SRet.isValid() && !InReg.isValid())
368 F->addParamAttr(1, SRet);
369 }
370 // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
371 // C ABI, but might show up in other cases.
372 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
373 IRBuilder<> IRB(BB);
374 Value *CalleePtr =
375 M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
376 Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
377 auto &DL = M->getDataLayout();
379
380 // Pass the called function in x9.
381 Args.push_back(F->arg_begin());
382
383 Type *RetTy = Arm64Ty->getReturnType();
384 if (RetTy != X64Ty->getReturnType()) {
385 // If the return type is an array or struct, translate it. Values of size
386 // 8 or less go into RAX; bigger values go into memory, and we pass a
387 // pointer.
388 if (DL.getTypeStoreSize(RetTy) > 8) {
389 Args.push_back(IRB.CreateAlloca(RetTy));
390 }
391 }
392
393 for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
394 // Translate arguments from AArch64 calling convention to x86 calling
395 // convention.
396 //
397 // For simple types, we don't need to do any translation: they're
398 // represented the same way. (Implicit sign extension is not part of
399 // either convention.)
400 //
401 // The big thing we have to worry about is struct types... but
402 // fortunately AArch64 clang is pretty friendly here: the cases that need
403 // translation are always passed as a struct or array. (If we run into
404 // some cases where this doesn't work, we can teach clang to mark it up
405 // with an attribute.)
406 //
407 // The first argument is the called function, stored in x9.
408 if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
409 DL.getTypeStoreSize(Arg.getType()) > 8) {
410 Value *Mem = IRB.CreateAlloca(Arg.getType());
411 IRB.CreateStore(&Arg, Mem);
412 if (DL.getTypeStoreSize(Arg.getType()) <= 8) {
413 Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
414 Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
415 } else
416 Args.push_back(Mem);
417 } else {
418 Args.push_back(&Arg);
419 }
420 }
421 // FIXME: Transfer necessary attributes? sret? anything else?
422
423 Callee = IRB.CreateBitCast(Callee, PtrTy);
424 CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
425 Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
426
427 Value *RetVal = Call;
428 if (RetTy != X64Ty->getReturnType()) {
429 // If we rewrote the return type earlier, convert the return value to
430 // the proper type.
431 if (DL.getTypeStoreSize(RetTy) > 8) {
432 RetVal = IRB.CreateLoad(RetTy, Args[1]);
433 } else {
434 Value *CastAlloca = IRB.CreateAlloca(RetTy);
435 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
436 RetVal = IRB.CreateLoad(RetTy, CastAlloca);
437 }
438 }
439
440 if (RetTy->isVoidTy())
441 IRB.CreateRetVoid();
442 else
443 IRB.CreateRet(RetVal);
444 return F;
445}
446
447// This function builds the "entry thunk", a function which translates
448// arguments and return values when calling AArch64 code from x64 code.
449Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
450 SmallString<256> EntryThunkName;
451 llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
452 FunctionType *Arm64Ty, *X64Ty;
453 getThunkType(F->getFunctionType(), F->getAttributes(),
454 Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty);
455 if (Function *F = M->getFunction(EntryThunkName))
456 return F;
457
459 EntryThunkName, M);
460 Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
461 Thunk->setSection(".wowthk$aa");
462 Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
463 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
464 Thunk->addFnAttr("frame-pointer", "all");
465
466 auto &DL = M->getDataLayout();
467 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
468 IRBuilder<> IRB(BB);
469
470 Type *RetTy = Arm64Ty->getReturnType();
471 Type *X64RetType = X64Ty->getReturnType();
472
473 bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
474 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
475 unsigned PassthroughArgSize = F->isVarArg() ? 5 : Thunk->arg_size();
476
477 // Translate arguments to call.
479 for (unsigned i = ThunkArgOffset, e = PassthroughArgSize; i != e; ++i) {
480 Value *Arg = Thunk->getArg(i);
481 Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
482 if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
483 DL.getTypeStoreSize(ArgTy) > 8) {
484 // Translate array/struct arguments to the expected type.
485 if (DL.getTypeStoreSize(ArgTy) <= 8) {
486 Value *CastAlloca = IRB.CreateAlloca(ArgTy);
487 IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
488 Arg = IRB.CreateLoad(ArgTy, CastAlloca);
489 } else {
490 Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
491 }
492 }
493 Args.push_back(Arg);
494 }
495
496 if (F->isVarArg()) {
497 // The 5th argument to variadic entry thunks is used to model the x64 sp
498 // which is passed to the thunk in x4, this can be passed to the callee as
499 // the variadic argument start address after skipping over the 32 byte
500 // shadow store.
501
502 // The EC thunk CC will assign any argument marked as InReg to x4.
503 Thunk->addParamAttr(5, Attribute::InReg);
504 Value *Arg = Thunk->getArg(5);
505 Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
506 Args.push_back(Arg);
507
508 // Pass in a zero variadic argument size (in x5).
509 Args.push_back(IRB.getInt64(0));
510 }
511
512 // Call the function passed to the thunk.
513 Value *Callee = Thunk->getArg(0);
514 Callee = IRB.CreateBitCast(Callee, PtrTy);
515 CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
516
517 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
518 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
519 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
520 Thunk->addParamAttr(1, SRetAttr);
521 Call->addParamAttr(0, SRetAttr);
522 }
523
524 Value *RetVal = Call;
525 if (TransformDirectToSRet) {
526 IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
527 } else if (X64RetType != RetTy) {
528 Value *CastAlloca = IRB.CreateAlloca(X64RetType);
529 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
530 RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
531 }
532
533 // Return to the caller. Note that the isel has code to translate this
534 // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
535 // could emit a tail call here, but that would require a dedicated calling
536 // convention, which seems more complicated overall.)
537 if (X64RetType->isVoidTy())
538 IRB.CreateRetVoid();
539 else
540 IRB.CreateRet(RetVal);
541
542 return Thunk;
543}
544
545// Builds the "guest exit thunk", a helper to call a function which may or may
546// not be an exit thunk. (We optimistically assume non-dllimport function
547// declarations refer to functions defined in AArch64 code; if the linker
548// can't prove that, we use this routine instead.)
549Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
550 llvm::raw_null_ostream NullThunkName;
551 FunctionType *Arm64Ty, *X64Ty;
552 getThunkType(F->getFunctionType(), F->getAttributes(),
553 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
554 auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
555 assert(MangledName && "Can't guest exit to function that's already native");
556 std::string ThunkName = *MangledName;
557 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
558 ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
559 } else {
560 ThunkName.append("$exit_thunk");
561 }
563 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
564 GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
565 GuestExit->setSection(".wowthk$aa");
566 GuestExit->setMetadata(
567 "arm64ec_unmangled_name",
568 MDNode::get(M->getContext(),
569 MDString::get(M->getContext(), F->getName())));
570 GuestExit->setMetadata(
571 "arm64ec_ecmangled_name",
572 MDNode::get(M->getContext(),
573 MDString::get(M->getContext(), *MangledName)));
574 F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
575 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
576 IRBuilder<> B(BB);
577
578 // Load the global symbol as a pointer to the check function.
579 Value *GuardFn;
580 if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
581 GuardFn = GuardFnCFGlobal;
582 else
583 GuardFn = GuardFnGlobal;
584 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
585
586 // Create new call instruction. The CFGuard check should always be a call,
587 // even if the original CallBase is an Invoke or CallBr instruction.
588 Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
589 CallInst *GuardCheck = B.CreateCall(
590 GuardFnType, GuardCheckLoad,
591 {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
592
593 // Ensure that the first argument is passed in the correct register.
595
596 Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
598 for (Argument &Arg : GuestExit->args())
599 Args.push_back(&Arg);
600 CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
601 Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
602
603 if (Call->getType()->isVoidTy())
604 B.CreateRetVoid();
605 else
606 B.CreateRet(Call);
607
608 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
609 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
610 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
611 GuestExit->addParamAttr(0, SRetAttr);
612 Call->addParamAttr(0, SRetAttr);
613 }
614
615 return GuestExit;
616}
617
618// Lower an indirect call with inline code.
619void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
621 "Only applicable for Windows targets");
622
623 IRBuilder<> B(CB);
624 Value *CalledOperand = CB->getCalledOperand();
625
626 // If the indirect call is called within catchpad or cleanuppad,
627 // we need to copy "funclet" bundle of the call.
629 if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
630 Bundles.push_back(OperandBundleDef(*Bundle));
631
632 // Load the global symbol as a pointer to the check function.
633 Value *GuardFn;
634 if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
635 GuardFn = GuardFnCFGlobal;
636 else
637 GuardFn = GuardFnGlobal;
638 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
639
640 // Create new call instruction. The CFGuard check should always be a call,
641 // even if the original CallBase is an Invoke or CallBr instruction.
642 Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
643 CallInst *GuardCheck =
644 B.CreateCall(GuardFnType, GuardCheckLoad,
645 {B.CreateBitCast(CalledOperand, B.getPtrTy()),
646 B.CreateBitCast(Thunk, B.getPtrTy())},
647 Bundles);
648
649 // Ensure that the first argument is passed in the correct register.
651
652 Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
653 CB->setCalledOperand(GuardRetVal);
654}
655
656bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
657 if (!GenerateThunks)
658 return false;
659
660 M = &Mod;
661
662 // Check if this module has the cfguard flag and read its value.
663 if (auto *MD =
664 mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
665 cfguard_module_flag = MD->getZExtValue();
666
667 PtrTy = PointerType::getUnqual(M->getContext());
668 I64Ty = Type::getInt64Ty(M->getContext());
669 VoidTy = Type::getVoidTy(M->getContext());
670
671 GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
672 GuardFnPtrType = PointerType::get(GuardFnType, 0);
673 GuardFnCFGlobal =
674 M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
675 GuardFnGlobal =
676 M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
677
678 SetVector<Function *> DirectCalledFns;
679 for (Function &F : Mod)
680 if (!F.isDeclaration() &&
681 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
682 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
683 processFunction(F, DirectCalledFns);
684
685 struct ThunkInfo {
686 Constant *Src;
687 Constant *Dst;
689 };
690 SmallVector<ThunkInfo> ThunkMapping;
691 for (Function &F : Mod) {
692 if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
693 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
694 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
695 if (!F.hasComdat())
696 F.setComdat(Mod.getOrInsertComdat(F.getName()));
697 ThunkMapping.push_back(
698 {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
699 }
700 }
701 for (Function *F : DirectCalledFns) {
702 ThunkMapping.push_back(
703 {F, buildExitThunk(F->getFunctionType(), F->getAttributes()),
704 Arm64ECThunkType::Exit});
705 if (!F->hasDLLImportStorageClass())
706 ThunkMapping.push_back(
707 {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
708 }
709
710 if (!ThunkMapping.empty()) {
711 SmallVector<Constant *> ThunkMappingArrayElems;
712 for (ThunkInfo &Thunk : ThunkMapping) {
713 ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
714 {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
716 ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))}));
717 }
718 Constant *ThunkMappingArray = ConstantArray::get(
719 llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
720 ThunkMappingArrayElems.size()),
721 ThunkMappingArrayElems);
722 new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
723 GlobalValue::ExternalLinkage, ThunkMappingArray,
724 "llvm.arm64ec.symbolmap");
725 }
726
727 return true;
728}
729
730bool AArch64Arm64ECCallLowering::processFunction(
731 Function &F, SetVector<Function *> &DirectCalledFns) {
732 SmallVector<CallBase *, 8> IndirectCalls;
733
734 // For ARM64EC targets, a function definition's name is mangled differently
735 // from the normal symbol. We currently have no representation of this sort
736 // of symbol in IR, so we change the name to the mangled name, then store
737 // the unmangled name as metadata. Later passes that need the unmangled
738 // name (emitting the definition) can grab it from the metadata.
739 //
740 // FIXME: Handle functions with weak linkage?
741 if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
742 if (std::optional<std::string> MangledName =
743 getArm64ECMangledFunctionName(F.getName().str())) {
744 F.setMetadata("arm64ec_unmangled_name",
745 MDNode::get(M->getContext(),
746 MDString::get(M->getContext(), F.getName())));
747 if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
748 Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
749 SmallVector<GlobalObject *> ComdatUsers =
750 to_vector(F.getComdat()->getUsers());
751 for (GlobalObject *User : ComdatUsers)
752 User->setComdat(MangledComdat);
753 }
754 F.setName(MangledName.value());
755 }
756 }
757
758 // Iterate over the instructions to find all indirect call/invoke/callbr
759 // instructions. Make a separate list of pointers to indirect
760 // call/invoke/callbr instructions because the original instructions will be
761 // deleted as the checks are added.
762 for (BasicBlock &BB : F) {
763 for (Instruction &I : BB) {
764 auto *CB = dyn_cast<CallBase>(&I);
766 CB->isInlineAsm())
767 continue;
768
769 // We need to instrument any call that isn't directly calling an
770 // ARM64 function.
771 //
772 // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
773 // unprototyped functions in C)
774 if (Function *F = CB->getCalledFunction()) {
775 if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
776 F->isIntrinsic() || !F->isDeclaration())
777 continue;
778
779 DirectCalledFns.insert(F);
780 continue;
781 }
782
783 IndirectCalls.push_back(CB);
784 ++Arm64ECCallsLowered;
785 }
786 }
787
788 if (IndirectCalls.empty())
789 return false;
790
791 for (CallBase *CB : IndirectCalls)
792 lowerCall(CB);
793
794 return true;
795}
796
797char AArch64Arm64ECCallLowering::ID = 0;
798INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
799 "AArch64Arm64ECCallLowering", false, false)
800
802 return new AArch64Arm64ECCallLowering;
803}
static cl::opt< bool > LowerDirectToIndirect("arm64ec-lower-direct-to-indirect", cl::Hidden, cl::init(true))
static cl::opt< bool > GenerateThunks("arm64ec-generate-thunks", cl::Hidden, cl::init(true))
OperandBundleDefT< Value * > OperandBundleDef
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
return RetTy
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
Module * Mod
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallString class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
Class for arbitrary precision integers.
Definition: APInt.h:76
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Definition: Type.cpp:647
Attribute getParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const
Return the attribute object that exists at the arg index.
Definition: Attributes.h:837
MaybeAlign getParamAlignment(unsigned ArgNo) const
Return the alignment for the specified function parameter.
LLVM Basic Block Representation.
Definition: BasicBlock.h:60
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:199
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1494
bool isInlineAsm() const
Check if this call is an inline asm statement.
Definition: InstrTypes.h:1809
void setCallingConv(CallingConv::ID CC)
Definition: InstrTypes.h:1804
std::optional< OperandBundleUse > getOperandBundle(StringRef Name) const
Return an operand bundle by name, if present.
Definition: InstrTypes.h:2411
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1742
bool hasFnAttr(Attribute::AttrKind Kind) const
Determine whether this call has the given attribute.
Definition: InstrTypes.h:1828
CallingConv::ID getCallingConv() const
Definition: InstrTypes.h:1800
Value * getCalledOperand() const
Definition: InstrTypes.h:1735
FunctionType * getFunctionType() const
Definition: InstrTypes.h:1600
void setCalledOperand(Value *V)
Definition: InstrTypes.h:1778
AttributeList getAttributes() const
Return the parameter attributes for this call.
Definition: InstrTypes.h:1819
This class represents a function call, abstracting a target machine's calling convention.
static Constant * get(ArrayType *T, ArrayRef< Constant * > V)
Definition: Constants.cpp:1291
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2140
static Constant * getAnon(ArrayRef< Constant * > V, bool Packed=false)
Return an anonymous struct that has the specified elements.
Definition: Constants.h:476
This is an important base class in LLVM.
Definition: Constant.h:41
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:164
@ WeakODRLinkage
Same, but only replaced by something equivalent.
Definition: GlobalValue.h:56
@ ExternalLinkage
Externally visible function.
Definition: GlobalValue.h:51
@ LinkOnceODRLinkage
Same, but only replaced by something equivalent.
Definition: GlobalValue.h:54
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2666
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
Definition: Instruction.cpp:83
An instruction for reading from memory.
Definition: Instructions.h:184
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
Definition: Metadata.h:1541
static MDString * get(LLVMContext &Context, StringRef Str)
Definition: Metadata.cpp:600
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition: Module.h:297
Comdat * getOrInsertComdat(StringRef Name)
Return the Comdat in the module with the specified name.
Definition: Module.cpp:589
A container for an operand bundle being viewed as a set of values rather than a set of uses.
Definition: InstrTypes.h:1447
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A vector that has set insertion semantics.
Definition: SetVector.h:57
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
Definition: SmallString.h:26
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
bool isOSWindows() const
Tests whether the OS is Windows.
Definition: Triple.h:624
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isArrayTy() const
True if this is an instance of ArrayType.
Definition: Type.h:252
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:255
bool isFloatTy() const
Return true if this is 'float', a 32-bit IEEE fp type.
Definition: Type.h:154
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
static Type * getVoidTy(LLVMContext &C)
bool isStructTy() const
True if this is an instance of StructType.
Definition: Type.h:249
bool isDoubleTy() const
Return true if this is 'double', a 64-bit IEEE fp type.
Definition: Type.h:157
static IntegerType * getInt64Ty(LLVMContext &C)
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:140
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
A raw_ostream that discards all output.
Definition: raw_ostream.h:723
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
A raw_ostream that writes to an SmallVector or SmallString.
Definition: raw_ostream.h:690
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
Arm64ECThunkType
Definition: COFF.h:809
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ ARM64EC_Thunk_Native
Calling convention used in the ARM64EC ABI to implement calls between ARM64 code and thunks.
Definition: CallingConv.h:265
@ CFGuard_Check
Special calling convention on Windows for calling the Control Guard Check ICall funtion.
Definition: CallingConv.h:82
@ ARM64EC_Thunk_X64
Calling convention used in the ARM64EC ABI to implement calls between x64 code and thunks.
Definition: CallingConv.h:260
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:450
constexpr double e
Definition: MathExtras.h:31
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
std::optional< std::string > getArm64ECMangledFunctionName(StringRef Name)
Definition: Mangler.cpp:293
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &)
ModulePass * createAArch64Arm64ECCallLoweringPass()
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:156
SmallVector< ValueTypeFromRangeType< R >, Size > to_vector(R &&Range)
Given a range of type R, iterate the entire range and return a SmallVector with elements of the vecto...
Definition: SmallVector.h:1312
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
uint64_t value() const
This is a hole in the type system and should not be abused.
Definition: Alignment.h:85
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
Definition: Alignment.h:141