LLVM 19.0.0git
AArch64Arm64ECCallLowering.cpp
Go to the documentation of this file.
1//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file contains the IR transform to lower external or indirect calls for
11/// the ARM64EC calling convention. Such calls must go through the runtime, so
12/// we can translate the calling convention for calls into the emulator.
13///
14/// This subsumes Control Flow Guard handling.
15///
16//===----------------------------------------------------------------------===//
17
#include "AArch64.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Mangler.h"
#include "llvm/InitializePasses.h"
#include "llvm/Object/COFF.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/TargetParser/Triple.h"
32
33using namespace llvm;
34using namespace llvm::COFF;
35
37
38#define DEBUG_TYPE "arm64eccalllowering"
39
40STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
41
42static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
43 cl::Hidden, cl::init(true));
44static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
45 cl::init(true));
46
47namespace {
48
49class AArch64Arm64ECCallLowering : public ModulePass {
50public:
51 static char ID;
52 AArch64Arm64ECCallLowering() : ModulePass(ID) {
54 }
55
56 Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
57 Function *buildEntryThunk(Function *F);
58 void lowerCall(CallBase *CB);
59 Function *buildGuestExitThunk(Function *F);
60 bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
61 bool runOnModule(Module &M) override;
62
63private:
64 int cfguard_module_flag = 0;
65 FunctionType *GuardFnType = nullptr;
66 PointerType *GuardFnPtrType = nullptr;
67 Constant *GuardFnCFGlobal = nullptr;
68 Constant *GuardFnGlobal = nullptr;
69 Module *M = nullptr;
70
71 Type *PtrTy;
72 Type *I64Ty;
73 Type *VoidTy;
74
75 void getThunkType(FunctionType *FT, AttributeList AttrList,
77 FunctionType *&Arm64Ty, FunctionType *&X64Ty);
78 void getThunkRetType(FunctionType *FT, AttributeList AttrList,
79 raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
80 SmallVectorImpl<Type *> &Arm64ArgTypes,
81 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
82 void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
84 SmallVectorImpl<Type *> &Arm64ArgTypes,
85 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
86 void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
87 uint64_t ArgSizeBytes, raw_ostream &Out,
88 Type *&Arm64Ty, Type *&X64Ty);
89};
90
91} // end anonymous namespace
92
93void AArch64Arm64ECCallLowering::getThunkType(
95 raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty) {
96 Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
97 : "$iexit_thunk$cdecl$");
98
99 Type *Arm64RetTy;
100 Type *X64RetTy;
101
102 SmallVector<Type *> Arm64ArgTypes;
103 SmallVector<Type *> X64ArgTypes;
104
105 // The first argument to a thunk is the called function, stored in x9.
106 // For exit thunks, we pass the called function down to the emulator;
107 // for entry/guest exit thunks, we just call the Arm64 function directly.
108 if (TT == Arm64ECThunkType::Exit)
109 Arm64ArgTypes.push_back(PtrTy);
110 X64ArgTypes.push_back(PtrTy);
111
112 bool HasSretPtr = false;
113 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
114 X64ArgTypes, HasSretPtr);
115
116 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
117 HasSretPtr);
118
119 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
120
121 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
122}
123
124void AArch64Arm64ECCallLowering::getThunkArgTypes(
126 raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
127 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
128
129 Out << "$";
130 if (FT->isVarArg()) {
131 // We treat the variadic function's thunk as a normal function
132 // with the following type on the ARM side:
133 // rettype exitthunk(
134 // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
135 //
136 // that can coverage all types of variadic function.
137 // x9 is similar to normal exit thunk, store the called function.
138 // x0-x3 is the arguments be stored in registers.
139 // x4 is the address of the arguments on the stack.
140 // x5 is the size of the arguments on the stack.
141 //
142 // On the x64 side, it's the same except that x5 isn't set.
143 //
144 // If both the ARM and X64 sides are sret, there are only three
145 // arguments in registers.
146 //
147 // If the X64 side is sret, but the ARM side isn't, we pass an extra value
148 // to/from the X64 side, and let SelectionDAG transform it into a memory
149 // location.
150 Out << "varargs";
151
152 // x0-x3
153 for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
154 Arm64ArgTypes.push_back(I64Ty);
155 X64ArgTypes.push_back(I64Ty);
156 }
157
158 // x4
159 Arm64ArgTypes.push_back(PtrTy);
160 X64ArgTypes.push_back(PtrTy);
161 // x5
162 Arm64ArgTypes.push_back(I64Ty);
163 if (TT != Arm64ECThunkType::Entry) {
164 // FIXME: x5 isn't actually used by the x64 side; revisit once we
165 // have proper isel for varargs
166 X64ArgTypes.push_back(I64Ty);
167 }
168 return;
169 }
170
171 unsigned I = 0;
172 if (HasSretPtr)
173 I++;
174
175 if (I == FT->getNumParams()) {
176 Out << "v";
177 return;
178 }
179
180 for (unsigned E = FT->getNumParams(); I != E; ++I) {
181#if 0
182 // FIXME: Need more information about argument size; see
183 // https://reviews.llvm.org/D132926
184 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
185 Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
186#else
187 uint64_t ArgSizeBytes = 0;
188 Align ParamAlign = Align();
189#endif
190 Type *Arm64Ty, *X64Ty;
191 canonicalizeThunkType(FT->getParamType(I), ParamAlign,
192 /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
193 Arm64ArgTypes.push_back(Arm64Ty);
194 X64ArgTypes.push_back(X64Ty);
195 }
196}
197
198void AArch64Arm64ECCallLowering::getThunkRetType(
199 FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
200 Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
201 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
202 Type *T = FT->getReturnType();
203#if 0
204 // FIXME: Need more information about argument size; see
205 // https://reviews.llvm.org/D132926
206 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
207#else
208 int64_t ArgSizeBytes = 0;
209#endif
210 if (T->isVoidTy()) {
211 if (FT->getNumParams()) {
212 auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet);
213 auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg);
214 if (SRetAttr.isValid() && InRegAttr.isValid()) {
215 // sret+inreg indicates a call that returns a C++ class value. This is
216 // actually equivalent to just passing and returning a void* pointer
217 // as the first argument. Translate it that way, instead of trying
218 // to model "inreg" in the thunk's calling convention, to simplify
219 // the rest of the code.
220 Out << "i8";
221 Arm64RetTy = I64Ty;
222 X64RetTy = I64Ty;
223 return;
224 }
225 if (SRetAttr.isValid()) {
226 // FIXME: Sanity-check the sret type; if it's an integer or pointer,
227 // we'll get screwy mangling/codegen.
228 // FIXME: For large struct types, mangle as an integer argument and
229 // integer return, so we can reuse more thunks, instead of "m" syntax.
230 // (MSVC mangles this case as an integer return with no argument, but
231 // that's a miscompile.)
232 Type *SRetType = SRetAttr.getValueAsType();
233 Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
234 Type *Arm64Ty, *X64Ty;
235 canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
236 Out, Arm64Ty, X64Ty);
237 Arm64RetTy = VoidTy;
238 X64RetTy = VoidTy;
239 Arm64ArgTypes.push_back(FT->getParamType(0));
240 X64ArgTypes.push_back(FT->getParamType(0));
241 HasSretPtr = true;
242 return;
243 }
244 }
245
246 Out << "v";
247 Arm64RetTy = VoidTy;
248 X64RetTy = VoidTy;
249 return;
250 }
251
252 canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64RetTy,
253 X64RetTy);
254 if (X64RetTy->isPointerTy()) {
255 // If the X64 type is canonicalized to a pointer, that means it's
256 // passed/returned indirectly. For a return value, that means it's an
257 // sret pointer.
258 X64ArgTypes.push_back(X64RetTy);
259 X64RetTy = VoidTy;
260 }
261}
262
263void AArch64Arm64ECCallLowering::canonicalizeThunkType(
264 Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
265 Type *&Arm64Ty, Type *&X64Ty) {
266 if (T->isFloatTy()) {
267 Out << "f";
268 Arm64Ty = T;
269 X64Ty = T;
270 return;
271 }
272
273 if (T->isDoubleTy()) {
274 Out << "d";
275 Arm64Ty = T;
276 X64Ty = T;
277 return;
278 }
279
280 if (T->isFloatingPointTy()) {
282 "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
283 }
284
285 auto &DL = M->getDataLayout();
286
287 if (auto *StructTy = dyn_cast<StructType>(T))
288 if (StructTy->getNumElements() == 1)
289 T = StructTy->getElementType(0);
290
291 if (T->isArrayTy()) {
292 Type *ElementTy = T->getArrayElementType();
293 uint64_t ElementCnt = T->getArrayNumElements();
294 uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
295 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
296 if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
297 Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
298 if (Alignment.value() >= 16 && !Ret)
299 Out << "a" << Alignment.value();
300 Arm64Ty = T;
301 if (TotalSizeBytes <= 8) {
302 // Arm64 returns small structs of float/double in float registers;
303 // X64 uses RAX.
304 X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8);
305 } else {
306 // Struct is passed directly on Arm64, but indirectly on X64.
307 X64Ty = PtrTy;
308 }
309 return;
310 } else if (T->isFloatingPointTy()) {
311 report_fatal_error("Only 32 and 64 bit floating points are supported for "
312 "ARM64EC thunks");
313 }
314 }
315
316 if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
317 Out << "i8";
318 Arm64Ty = I64Ty;
319 X64Ty = I64Ty;
320 return;
321 }
322
323 unsigned TypeSize = ArgSizeBytes;
324 if (TypeSize == 0)
325 TypeSize = DL.getTypeSizeInBits(T) / 8;
326 Out << "m";
327 if (TypeSize != 4)
328 Out << TypeSize;
329 if (Alignment.value() >= 16 && !Ret)
330 Out << "a" << Alignment.value();
331 // FIXME: Try to canonicalize Arm64Ty more thoroughly?
332 Arm64Ty = T;
333 if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
334 // Pass directly in an integer register
335 X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8);
336 } else {
337 // Passed directly on Arm64, but indirectly on X64.
338 X64Ty = PtrTy;
339 }
340}
341
342// This function builds the "exit thunk", a function which translates
343// arguments and return values when calling x64 code from AArch64 code.
344Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
345 AttributeList Attrs) {
346 SmallString<256> ExitThunkName;
347 llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
348 FunctionType *Arm64Ty, *X64Ty;
349 getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
350 X64Ty);
351 if (Function *F = M->getFunction(ExitThunkName))
352 return F;
353
355 ExitThunkName, M);
356 F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
357 F->setSection(".wowthk$aa");
358 F->setComdat(M->getOrInsertComdat(ExitThunkName));
359 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
360 F->addFnAttr("frame-pointer", "all");
361 // Only copy sret from the first argument. For C++ instance methods, clang can
362 // stick an sret marking on a later argument, but it doesn't actually affect
363 // the ABI, so we can omit it. This avoids triggering a verifier assertion.
364 if (FT->getNumParams()) {
365 auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
366 auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
367 if (SRet.isValid() && !InReg.isValid())
368 F->addParamAttr(1, SRet);
369 }
370 // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
371 // C ABI, but might show up in other cases.
372 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
373 IRBuilder<> IRB(BB);
374 Value *CalleePtr =
375 M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
376 Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
377 auto &DL = M->getDataLayout();
379
380 // Pass the called function in x9.
381 Args.push_back(F->arg_begin());
382
383 Type *RetTy = Arm64Ty->getReturnType();
384 if (RetTy != X64Ty->getReturnType()) {
385 // If the return type is an array or struct, translate it. Values of size
386 // 8 or less go into RAX; bigger values go into memory, and we pass a
387 // pointer.
388 if (DL.getTypeStoreSize(RetTy) > 8) {
389 Args.push_back(IRB.CreateAlloca(RetTy));
390 }
391 }
392
393 for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
394 // Translate arguments from AArch64 calling convention to x86 calling
395 // convention.
396 //
397 // For simple types, we don't need to do any translation: they're
398 // represented the same way. (Implicit sign extension is not part of
399 // either convention.)
400 //
401 // The big thing we have to worry about is struct types... but
402 // fortunately AArch64 clang is pretty friendly here: the cases that need
403 // translation are always passed as a struct or array. (If we run into
404 // some cases where this doesn't work, we can teach clang to mark it up
405 // with an attribute.)
406 //
407 // The first argument is the called function, stored in x9.
408 if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
409 DL.getTypeStoreSize(Arg.getType()) > 8) {
410 Value *Mem = IRB.CreateAlloca(Arg.getType());
411 IRB.CreateStore(&Arg, Mem);
412 if (DL.getTypeStoreSize(Arg.getType()) <= 8) {
413 Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
414 Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
415 } else
416 Args.push_back(Mem);
417 } else {
418 Args.push_back(&Arg);
419 }
420 }
421 // FIXME: Transfer necessary attributes? sret? anything else?
422
423 Callee = IRB.CreateBitCast(Callee, PtrTy);
424 CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
425 Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
426
427 Value *RetVal = Call;
428 if (RetTy != X64Ty->getReturnType()) {
429 // If we rewrote the return type earlier, convert the return value to
430 // the proper type.
431 if (DL.getTypeStoreSize(RetTy) > 8) {
432 RetVal = IRB.CreateLoad(RetTy, Args[1]);
433 } else {
434 Value *CastAlloca = IRB.CreateAlloca(RetTy);
435 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
436 RetVal = IRB.CreateLoad(RetTy, CastAlloca);
437 }
438 }
439
440 if (RetTy->isVoidTy())
441 IRB.CreateRetVoid();
442 else
443 IRB.CreateRet(RetVal);
444 return F;
445}
446
447// This function builds the "entry thunk", a function which translates
448// arguments and return values when calling AArch64 code from x64 code.
449Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
450 SmallString<256> EntryThunkName;
451 llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
452 FunctionType *Arm64Ty, *X64Ty;
453 getThunkType(F->getFunctionType(), F->getAttributes(),
454 Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty);
455 if (Function *F = M->getFunction(EntryThunkName))
456 return F;
457
459 EntryThunkName, M);
460 Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
461 Thunk->setSection(".wowthk$aa");
462 Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
463 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
464 Thunk->addFnAttr("frame-pointer", "all");
465
466 auto &DL = M->getDataLayout();
467 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
468 IRBuilder<> IRB(BB);
469
470 Type *RetTy = Arm64Ty->getReturnType();
471 Type *X64RetType = X64Ty->getReturnType();
472
473 bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
474 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
475 unsigned PassthroughArgSize = F->isVarArg() ? 5 : Thunk->arg_size();
476
477 // Translate arguments to call.
479 for (unsigned i = ThunkArgOffset, e = PassthroughArgSize; i != e; ++i) {
480 Value *Arg = Thunk->getArg(i);
481 Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
482 if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
483 DL.getTypeStoreSize(ArgTy) > 8) {
484 // Translate array/struct arguments to the expected type.
485 if (DL.getTypeStoreSize(ArgTy) <= 8) {
486 Value *CastAlloca = IRB.CreateAlloca(ArgTy);
487 IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
488 Arg = IRB.CreateLoad(ArgTy, CastAlloca);
489 } else {
490 Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
491 }
492 }
493 Args.push_back(Arg);
494 }
495
496 if (F->isVarArg()) {
497 // The 5th argument to variadic entry thunks is used to model the x64 sp
498 // which is passed to the thunk in x4, this can be passed to the callee as
499 // the variadic argument start address after skipping over the 32 byte
500 // shadow store.
501
502 // The EC thunk CC will assign any argument marked as InReg to x4.
503 Thunk->addParamAttr(5, Attribute::InReg);
504 Value *Arg = Thunk->getArg(5);
505 Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
506 Args.push_back(Arg);
507
508 // Pass in a zero variadic argument size (in x5).
509 Args.push_back(IRB.getInt64(0));
510 }
511
512 // Call the function passed to the thunk.
513 Value *Callee = Thunk->getArg(0);
514 Callee = IRB.CreateBitCast(Callee, PtrTy);
515 Value *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
516
517 Value *RetVal = Call;
518 if (TransformDirectToSRet) {
519 IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
520 } else if (X64RetType != RetTy) {
521 Value *CastAlloca = IRB.CreateAlloca(X64RetType);
522 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
523 RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
524 }
525
526 // Return to the caller. Note that the isel has code to translate this
527 // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
528 // could emit a tail call here, but that would require a dedicated calling
529 // convention, which seems more complicated overall.)
530 if (X64RetType->isVoidTy())
531 IRB.CreateRetVoid();
532 else
533 IRB.CreateRet(RetVal);
534
535 return Thunk;
536}
537
538// Builds the "guest exit thunk", a helper to call a function which may or may
539// not be an exit thunk. (We optimistically assume non-dllimport function
540// declarations refer to functions defined in AArch64 code; if the linker
541// can't prove that, we use this routine instead.)
542Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
543 llvm::raw_null_ostream NullThunkName;
544 FunctionType *Arm64Ty, *X64Ty;
545 getThunkType(F->getFunctionType(), F->getAttributes(),
546 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
547 auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
548 assert(MangledName && "Can't guest exit to function that's already native");
549 std::string ThunkName = *MangledName;
550 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
551 ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
552 } else {
553 ThunkName.append("$exit_thunk");
554 }
556 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
557 GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
558 GuestExit->setSection(".wowthk$aa");
559 GuestExit->setMetadata(
560 "arm64ec_unmangled_name",
561 MDNode::get(M->getContext(),
562 MDString::get(M->getContext(), F->getName())));
563 GuestExit->setMetadata(
564 "arm64ec_ecmangled_name",
565 MDNode::get(M->getContext(),
566 MDString::get(M->getContext(), *MangledName)));
567 F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
568 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
569 IRBuilder<> B(BB);
570
571 // Load the global symbol as a pointer to the check function.
572 Value *GuardFn;
573 if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
574 GuardFn = GuardFnCFGlobal;
575 else
576 GuardFn = GuardFnGlobal;
577 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
578
579 // Create new call instruction. The CFGuard check should always be a call,
580 // even if the original CallBase is an Invoke or CallBr instruction.
581 Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
582 CallInst *GuardCheck = B.CreateCall(
583 GuardFnType, GuardCheckLoad,
584 {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
585
586 // Ensure that the first argument is passed in the correct register.
588
589 Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
591 for (Argument &Arg : GuestExit->args())
592 Args.push_back(&Arg);
593 CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
594 Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
595
596 if (Call->getType()->isVoidTy())
597 B.CreateRetVoid();
598 else
599 B.CreateRet(Call);
600
601 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
602 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
603 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
604 GuestExit->addParamAttr(0, SRetAttr);
605 Call->addParamAttr(0, SRetAttr);
606 }
607
608 return GuestExit;
609}
610
611// Lower an indirect call with inline code.
612void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
614 "Only applicable for Windows targets");
615
616 IRBuilder<> B(CB);
617 Value *CalledOperand = CB->getCalledOperand();
618
619 // If the indirect call is called within catchpad or cleanuppad,
620 // we need to copy "funclet" bundle of the call.
622 if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
623 Bundles.push_back(OperandBundleDef(*Bundle));
624
625 // Load the global symbol as a pointer to the check function.
626 Value *GuardFn;
627 if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
628 GuardFn = GuardFnCFGlobal;
629 else
630 GuardFn = GuardFnGlobal;
631 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
632
633 // Create new call instruction. The CFGuard check should always be a call,
634 // even if the original CallBase is an Invoke or CallBr instruction.
635 Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
636 CallInst *GuardCheck =
637 B.CreateCall(GuardFnType, GuardCheckLoad,
638 {B.CreateBitCast(CalledOperand, B.getPtrTy()),
639 B.CreateBitCast(Thunk, B.getPtrTy())},
640 Bundles);
641
642 // Ensure that the first argument is passed in the correct register.
644
645 Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
646 CB->setCalledOperand(GuardRetVal);
647}
648
649bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
650 if (!GenerateThunks)
651 return false;
652
653 M = &Mod;
654
655 // Check if this module has the cfguard flag and read its value.
656 if (auto *MD =
657 mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
658 cfguard_module_flag = MD->getZExtValue();
659
660 PtrTy = PointerType::getUnqual(M->getContext());
661 I64Ty = Type::getInt64Ty(M->getContext());
662 VoidTy = Type::getVoidTy(M->getContext());
663
664 GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
665 GuardFnPtrType = PointerType::get(GuardFnType, 0);
666 GuardFnCFGlobal =
667 M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
668 GuardFnGlobal =
669 M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
670
671 SetVector<Function *> DirectCalledFns;
672 for (Function &F : Mod)
673 if (!F.isDeclaration() &&
674 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
675 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
676 processFunction(F, DirectCalledFns);
677
678 struct ThunkInfo {
679 Constant *Src;
680 Constant *Dst;
682 };
683 SmallVector<ThunkInfo> ThunkMapping;
684 for (Function &F : Mod) {
685 if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
686 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
687 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
688 if (!F.hasComdat())
689 F.setComdat(Mod.getOrInsertComdat(F.getName()));
690 ThunkMapping.push_back(
691 {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
692 }
693 }
694 for (Function *F : DirectCalledFns) {
695 ThunkMapping.push_back(
696 {F, buildExitThunk(F->getFunctionType(), F->getAttributes()),
697 Arm64ECThunkType::Exit});
698 if (!F->hasDLLImportStorageClass())
699 ThunkMapping.push_back(
700 {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
701 }
702
703 if (!ThunkMapping.empty()) {
704 SmallVector<Constant *> ThunkMappingArrayElems;
705 for (ThunkInfo &Thunk : ThunkMapping) {
706 ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
707 {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
709 ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))}));
710 }
711 Constant *ThunkMappingArray = ConstantArray::get(
712 llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
713 ThunkMappingArrayElems.size()),
714 ThunkMappingArrayElems);
715 new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
716 GlobalValue::ExternalLinkage, ThunkMappingArray,
717 "llvm.arm64ec.symbolmap");
718 }
719
720 return true;
721}
722
723bool AArch64Arm64ECCallLowering::processFunction(
724 Function &F, SetVector<Function *> &DirectCalledFns) {
725 SmallVector<CallBase *, 8> IndirectCalls;
726
727 // For ARM64EC targets, a function definition's name is mangled differently
728 // from the normal symbol. We currently have no representation of this sort
729 // of symbol in IR, so we change the name to the mangled name, then store
730 // the unmangled name as metadata. Later passes that need the unmangled
731 // name (emitting the definition) can grab it from the metadata.
732 //
733 // FIXME: Handle functions with weak linkage?
734 if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
735 if (std::optional<std::string> MangledName =
736 getArm64ECMangledFunctionName(F.getName().str())) {
737 F.setMetadata("arm64ec_unmangled_name",
738 MDNode::get(M->getContext(),
739 MDString::get(M->getContext(), F.getName())));
740 if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
741 Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
742 SmallVector<GlobalObject *> ComdatUsers =
743 to_vector(F.getComdat()->getUsers());
744 for (GlobalObject *User : ComdatUsers)
745 User->setComdat(MangledComdat);
746 }
747 F.setName(MangledName.value());
748 }
749 }
750
751 // Iterate over the instructions to find all indirect call/invoke/callbr
752 // instructions. Make a separate list of pointers to indirect
753 // call/invoke/callbr instructions because the original instructions will be
754 // deleted as the checks are added.
755 for (BasicBlock &BB : F) {
756 for (Instruction &I : BB) {
757 auto *CB = dyn_cast<CallBase>(&I);
759 CB->isInlineAsm())
760 continue;
761
762 // We need to instrument any call that isn't directly calling an
763 // ARM64 function.
764 //
765 // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
766 // unprototyped functions in C)
767 if (Function *F = CB->getCalledFunction()) {
768 if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
769 F->isIntrinsic() || !F->isDeclaration())
770 continue;
771
772 DirectCalledFns.insert(F);
773 continue;
774 }
775
776 IndirectCalls.push_back(CB);
777 ++Arm64ECCallsLowered;
778 }
779 }
780
781 if (IndirectCalls.empty())
782 return false;
783
784 for (CallBase *CB : IndirectCalls)
785 lowerCall(CB);
786
787 return true;
788}
789
790char AArch64Arm64ECCallLowering::ID = 0;
791INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
792 "AArch64Arm64ECCallLowering", false, false)
793
795 return new AArch64Arm64ECCallLowering;
796}
static cl::opt< bool > LowerDirectToIndirect("arm64ec-lower-direct-to-indirect", cl::Hidden, cl::init(true))
static cl::opt< bool > GenerateThunks("arm64ec-generate-thunks", cl::Hidden, cl::init(true))
OperandBundleDefT< Value * > OperandBundleDef
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
return RetTy
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
Module * Mod
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallString class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
Class for arbitrary precision integers.
Definition: APInt.h:76
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Definition: Type.cpp:647
Attribute getParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const
Return the attribute object that exists at the arg index.
Definition: Attributes.h:837
MaybeAlign getParamAlignment(unsigned ArgNo) const
Return the alignment for the specified function parameter.
LLVM Basic Block Representation.
Definition: BasicBlock.h:60
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:199
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1494
bool isInlineAsm() const
Check if this call is an inline asm statement.
Definition: InstrTypes.h:1809
void setCallingConv(CallingConv::ID CC)
Definition: InstrTypes.h:1804
std::optional< OperandBundleUse > getOperandBundle(StringRef Name) const
Return an operand bundle by name, if present.
Definition: InstrTypes.h:2405
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1742
bool hasFnAttr(Attribute::AttrKind Kind) const
Determine whether this call has the given attribute.
Definition: InstrTypes.h:1828
CallingConv::ID getCallingConv() const
Definition: InstrTypes.h:1800
Value * getCalledOperand() const
Definition: InstrTypes.h:1735
FunctionType * getFunctionType() const
Definition: InstrTypes.h:1600
void setCalledOperand(Value *V)
Definition: InstrTypes.h:1778
AttributeList getAttributes() const
Return the parameter attributes for this call.
Definition: InstrTypes.h:1819
This class represents a function call, abstracting a target machine's calling convention.
static Constant * get(ArrayType *T, ArrayRef< Constant * > V)
Definition: Constants.cpp:1291
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2140
static Constant * getAnon(ArrayRef< Constant * > V, bool Packed=false)
Return an anonymous struct that has the specified elements.
Definition: Constants.h:476
This is an important base class in LLVM.
Definition: Constant.h:41
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:164
@ WeakODRLinkage
Same as WeakAnyLinkage, but only replaced by something equivalent.
Definition: GlobalValue.h:57
@ ExternalLinkage
Externally visible function.
Definition: GlobalValue.h:52
@ LinkOnceODRLinkage
Same as LinkOnceAnyLinkage, but only replaced by something equivalent.
Definition: GlobalValue.h:55
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2666
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr if the function does not...
Definition: Instruction.cpp:83
An instruction for reading from memory.
Definition: Instructions.h:184
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
Definition: Metadata.h:1541
static MDString * get(LLVMContext &Context, StringRef Str)
Definition: Metadata.cpp:600
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition: Module.h:297
Comdat * getOrInsertComdat(StringRef Name)
Return the Comdat in the module with the specified name.
Definition: Module.cpp:589
A container for an operand bundle being viewed as a set of values rather than a set of uses.
Definition: InstrTypes.h:1447
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A vector that has set insertion semantics.
Definition: SetVector.h:57
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
Definition: SmallString.h:26
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
bool isOSWindows() const
Tests whether the OS is Windows.
Definition: Triple.h:619
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isArrayTy() const
True if this is an instance of ArrayType.
Definition: Type.h:252
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:255
bool isFloatTy() const
Return true if this is 'float', a 32-bit IEEE fp type.
Definition: Type.h:154
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
static Type * getVoidTy(LLVMContext &C)
bool isStructTy() const
True if this is an instance of StructType.
Definition: Type.h:249
bool isDoubleTy() const
Return true if this is 'double', a 64-bit IEEE fp type.
Definition: Type.h:157
static IntegerType * getInt64Ty(LLVMContext &C)
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:140
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
A raw_ostream that discards all output.
Definition: raw_ostream.h:723
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
A raw_ostream that writes to a SmallVector or SmallString.
Definition: raw_ostream.h:690
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
Arm64ECThunkType
Definition: COFF.h:809
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
@ ARM64EC_Thunk_Native
Calling convention used in the ARM64EC ABI to implement calls between ARM64 code and thunks.
Definition: CallingConv.h:265
@ CFGuard_Check
Special calling convention on Windows for calling the Control Flow Guard Check ICall function.
Definition: CallingConv.h:82
@ ARM64EC_Thunk_X64
Calling convention used in the ARM64EC ABI to implement calls between x64 code and thunks.
Definition: CallingConv.h:260
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:450
constexpr double e
Definition: MathExtras.h:31
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
std::optional< std::string > getArm64ECMangledFunctionName(StringRef Name)
Definition: Mangler.cpp:293
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &)
ModulePass * createAArch64Arm64ECCallLoweringPass()
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:156
SmallVector< ValueTypeFromRangeType< R >, Size > to_vector(R &&Range)
Given a range of type R, iterate the entire range and return a SmallVector with elements of the vecto...
Definition: SmallVector.h:1312
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
uint64_t value() const
This is a hole in the type system and should not be abused.
Definition: Alignment.h:85
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
Definition: Alignment.h:141