1 //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file contains the IR transform to lower external or indirect calls for
11 /// the ARM64EC calling convention. Such calls must go through the runtime, so
12 /// we can translate the calling convention for calls into the emulator.
13 ///
14 /// This subsumes Control Flow Guard handling.
15 ///
16 //===----------------------------------------------------------------------===//
17
18 #include "AArch64.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/IR/CallingConv.h"
24 #include "llvm/IR/GlobalAlias.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Instruction.h"
27 #include "llvm/IR/Mangler.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/Object/COFF.h"
30 #include "llvm/Pass.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/TargetParser/Triple.h"
33
34 using namespace llvm;
35 using namespace llvm::COFF;
36
37 using OperandBundleDef = OperandBundleDefT<Value *>;
38
39 #define DEBUG_TYPE "arm64eccalllowering"
40
// Counts calls rewritten by this pass (reported via -stats).
STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");

// NOTE(review): used outside this chunk (presumably in processFunction) to
// control rewriting direct calls into indirect guest-exit calls — confirm.
static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
                                           cl::Hidden, cl::init(true));
// Master switch for the pass; when false, runOnModule bails out immediately.
static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
                                    cl::init(true));
47
namespace {

// Describes how one argument (or return value) is converted when it crosses
// the AArch64 <-> x64 convention boundary inside a thunk.
enum ThunkArgTranslation : uint8_t {
  // Identical representation on both sides; passed through unchanged.
  Direct,
  // Same size but reinterpreted: stored to a temporary and reloaded as an
  // integer of the same width (e.g. small float/double aggregates in RAX).
  Bitcast,
  // Passed by value on Arm64, but via a pointer to a temporary on x64.
  PointerIndirection,
};

// The canonicalized Arm64/x64 types for one value plus how to translate it.
struct ThunkArgInfo {
  Type *Arm64Ty;
  Type *X64Ty;
  ThunkArgTranslation Translation;
};

// Module pass implementing the transform described in the file header: it
// builds entry/exit/guest-exit thunks and rewrites convention-crossing calls
// to go through the OS-provided dispatch/check helpers.
class AArch64Arm64ECCallLowering : public ModulePass {
public:
  static char ID;
  AArch64Arm64ECCallLowering() : ModulePass(ID) {}

  // Build (or reuse) the exit thunk used when Arm64 code calls x64 code.
  Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
  // Build (or reuse) the entry thunk used when x64 code calls Arm64 code.
  Function *buildEntryThunk(Function *F);
  // Rewrite an indirect call to go through __os_arm64x_check_icall[_cfg].
  void lowerCall(CallBase *CB);
  // Build the thunk used for calls to possibly-x64 function declarations.
  Function *buildGuestExitThunk(Function *F);
  // Build the dispatch thunk for hybrid_patchable functions.
  Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,
                                GlobalAlias *MangledAlias);
  bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,
                       DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
  bool runOnModule(Module &M) override;

private:
  // Value of the module's "cfguard" flag; 2 enables Control Flow Guard
  // checks (see lowerCall/buildGuestExitThunk).
  int cfguard_module_flag = 0;
  // Type of __os_arm64x_check_icall[_cfg]: ptr(ptr, ptr).
  FunctionType *GuardFnType = nullptr;
  // Type of __os_arm64x_dispatch_call: ptr(ptr, ptr, ptr).
  FunctionType *DispatchFnType = nullptr;
  Constant *GuardFnCFGlobal = nullptr;
  Constant *GuardFnGlobal = nullptr;
  Constant *DispatchFnGlobal = nullptr;
  Module *M = nullptr;

  // Frequently used types, cached from the module context in runOnModule.
  Type *PtrTy;
  Type *I64Ty;
  Type *VoidTy;

  // Compute the Arm64/x64 function types for a thunk and the per-argument
  // translation plan, streaming the thunk's mangled-name suffix into Out.
  void getThunkType(FunctionType *FT, AttributeList AttrList,
                    Arm64ECThunkType TT, raw_ostream &Out,
                    FunctionType *&Arm64Ty, FunctionType *&X64Ty,
                    SmallVector<ThunkArgTranslation> &ArgTranslations);
  void getThunkRetType(FunctionType *FT, AttributeList AttrList,
                       raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
                       SmallVectorImpl<Type *> &Arm64ArgTypes,
                       SmallVectorImpl<Type *> &X64ArgTypes,
                       SmallVector<ThunkArgTranslation> &ArgTranslations,
                       bool &HasSretPtr);
  void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
                        Arm64ECThunkType TT, raw_ostream &Out,
                        SmallVectorImpl<Type *> &Arm64ArgTypes,
                        SmallVectorImpl<Type *> &X64ArgTypes,
                        SmallVectorImpl<ThunkArgTranslation> &ArgTranslations,
                        bool HasSretPtr);
  ThunkArgInfo canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
                                     uint64_t ArgSizeBytes, raw_ostream &Out);
};

} // end anonymous namespace
111
getThunkType(FunctionType * FT,AttributeList AttrList,Arm64ECThunkType TT,raw_ostream & Out,FunctionType * & Arm64Ty,FunctionType * & X64Ty,SmallVector<ThunkArgTranslation> & ArgTranslations)112 void AArch64Arm64ECCallLowering::getThunkType(
113 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
114 raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty,
115 SmallVector<ThunkArgTranslation> &ArgTranslations) {
116 Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
117 : "$iexit_thunk$cdecl$");
118
119 Type *Arm64RetTy;
120 Type *X64RetTy;
121
122 SmallVector<Type *> Arm64ArgTypes;
123 SmallVector<Type *> X64ArgTypes;
124
125 // The first argument to a thunk is the called function, stored in x9.
126 // For exit thunks, we pass the called function down to the emulator;
127 // for entry/guest exit thunks, we just call the Arm64 function directly.
128 if (TT == Arm64ECThunkType::Exit)
129 Arm64ArgTypes.push_back(PtrTy);
130 X64ArgTypes.push_back(PtrTy);
131
132 bool HasSretPtr = false;
133 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
134 X64ArgTypes, ArgTranslations, HasSretPtr);
135
136 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
137 ArgTranslations, HasSretPtr);
138
139 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
140
141 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
142 }
143
// Append the thunk's argument types to Arm64ArgTypes/X64ArgTypes, record how
// each must be translated, and stream the argument portion of the mangled
// name to Out. HasSretPtr means the function's first IR parameter is an sret
// pointer already consumed by getThunkRetType, so it is skipped here.
void AArch64Arm64ECCallLowering::getThunkArgTypes(
    FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
    raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
    SmallVectorImpl<Type *> &X64ArgTypes,
    SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) {

  Out << "$";
  if (FT->isVarArg()) {
    // We treat the variadic function's thunk as a normal function
    // with the following type on the ARM side:
    //   rettype exitthunk(
    //     ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
    //
    // that can coverage all types of variadic function.
    // x9 is similar to normal exit thunk, store the called function.
    // x0-x3 is the arguments be stored in registers.
    // x4 is the address of the arguments on the stack.
    // x5 is the size of the arguments on the stack.
    //
    // On the x64 side, it's the same except that x5 isn't set.
    //
    // If both the ARM and X64 sides are sret, there are only three
    // arguments in registers.
    //
    // If the X64 side is sret, but the ARM side isn't, we pass an extra value
    // to/from the X64 side, and let SelectionDAG transform it into a memory
    // location.
    Out << "varargs";

    // x0-x3 (one register is already taken by the sret pointer, if any)
    for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
      Arm64ArgTypes.push_back(I64Ty);
      X64ArgTypes.push_back(I64Ty);
      ArgTranslations.push_back(ThunkArgTranslation::Direct);
    }

    // x4
    Arm64ArgTypes.push_back(PtrTy);
    X64ArgTypes.push_back(PtrTy);
    ArgTranslations.push_back(ThunkArgTranslation::Direct);
    // x5
    Arm64ArgTypes.push_back(I64Ty);
    if (TT != Arm64ECThunkType::Entry) {
      // FIXME: x5 isn't actually used by the x64 side; revisit once we
      // have proper isel for varargs
      X64ArgTypes.push_back(I64Ty);
      ArgTranslations.push_back(ThunkArgTranslation::Direct);
    }
    return;
  }

  unsigned I = 0;
  if (HasSretPtr)
    I++;

  // No parameters beyond the sret pointer: mangle as "v" (void).
  if (I == FT->getNumParams()) {
    Out << "v";
    return;
  }

  for (unsigned E = FT->getNumParams(); I != E; ++I) {
#if 0
    // FIXME: Need more information about argument size; see
    // https://reviews.llvm.org/D132926
    uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
    Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
#else
    uint64_t ArgSizeBytes = 0;
    Align ParamAlign = Align();
#endif
    auto [Arm64Ty, X64Ty, ArgTranslation] =
        canonicalizeThunkType(FT->getParamType(I), ParamAlign,
                              /*Ret*/ false, ArgSizeBytes, Out);
    Arm64ArgTypes.push_back(Arm64Ty);
    X64ArgTypes.push_back(X64Ty);
    ArgTranslations.push_back(ArgTranslation);
  }
}
222
// Determine the return types for both sides of the thunk and stream the
// return-type portion of the mangled name to Out. A struct-style return may
// be turned into an sret pointer argument on one or both sides; HasSretPtr
// is set when the function's first IR parameter is consumed here as that
// sret pointer (so getThunkArgTypes must skip it).
void AArch64Arm64ECCallLowering::getThunkRetType(
    FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
    Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
    SmallVectorImpl<Type *> &X64ArgTypes,
    SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) {
  Type *T = FT->getReturnType();
#if 0
  // FIXME: Need more information about argument size; see
  // https://reviews.llvm.org/D132926
  uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
#else
  int64_t ArgSizeBytes = 0;
#endif
  if (T->isVoidTy()) {
    if (FT->getNumParams()) {
      Attribute SRetAttr0 = AttrList.getParamAttr(0, Attribute::StructRet);
      Attribute InRegAttr0 = AttrList.getParamAttr(0, Attribute::InReg);
      Attribute SRetAttr1, InRegAttr1;
      if (FT->getNumParams() > 1) {
        // Also check the second parameter (for class methods, the first
        // parameter is "this", and the second parameter is the sret pointer.)
        // It doesn't matter which one is sret.
        SRetAttr1 = AttrList.getParamAttr(1, Attribute::StructRet);
        InRegAttr1 = AttrList.getParamAttr(1, Attribute::InReg);
      }
      if ((SRetAttr0.isValid() && InRegAttr0.isValid()) ||
          (SRetAttr1.isValid() && InRegAttr1.isValid())) {
        // sret+inreg indicates a call that returns a C++ class value. This is
        // actually equivalent to just passing and returning a void* pointer
        // as the first or second argument. Translate it that way, instead of
        // trying to model "inreg" in the thunk's calling convention; this
        // simplfies the rest of the code, and matches MSVC mangling.
        Out << "i8";
        Arm64RetTy = I64Ty;
        X64RetTy = I64Ty;
        return;
      }
      if (SRetAttr0.isValid()) {
        // FIXME: Sanity-check the sret type; if it's an integer or pointer,
        // we'll get screwy mangling/codegen.
        // FIXME: For large struct types, mangle as an integer argument and
        // integer return, so we can reuse more thunks, instead of "m" syntax.
        // (MSVC mangles this case as an integer return with no argument, but
        // that's a miscompile.)
        // canonicalizeThunkType is called only for its mangling side effect
        // here; the sret pointer itself is passed through directly.
        Type *SRetType = SRetAttr0.getValueAsType();
        Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
        canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
                              Out);
        Arm64RetTy = VoidTy;
        X64RetTy = VoidTy;
        Arm64ArgTypes.push_back(FT->getParamType(0));
        X64ArgTypes.push_back(FT->getParamType(0));
        ArgTranslations.push_back(ThunkArgTranslation::Direct);
        HasSretPtr = true;
        return;
      }
    }

    // Plain void return on both sides.
    Out << "v";
    Arm64RetTy = VoidTy;
    X64RetTy = VoidTy;
    return;
  }

  auto info =
      canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out);
  Arm64RetTy = info.Arm64Ty;
  X64RetTy = info.X64Ty;
  if (X64RetTy->isPointerTy()) {
    // If the X64 type is canonicalized to a pointer, that means it's
    // passed/returned indirectly. For a return value, that means it's an
    // sret pointer.
    X64ArgTypes.push_back(X64RetTy);
    X64RetTy = VoidTy;
  }
}
299
// Map type T to the pair of types used on the Arm64 and x64 sides of a thunk
// and pick its translation strategy, appending the MSVC-style mangling for it
// to Out ("f"/"d" floats, "i8" register-sized ints/ptrs, "F"/"D" small
// float/double aggregates, "m" other memory types). A nonzero ArgSizeBytes
// overrides the size computed from the data layout; Ret indicates T is the
// return value, which suppresses the "a<align>" suffix.
ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
    Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,
    raw_ostream &Out) {

  // Helpers for the three possible outcomes; see ThunkArgTranslation.
  auto direct = [](Type *T) {
    return ThunkArgInfo{T, T, ThunkArgTranslation::Direct};
  };

  auto bitcast = [this](Type *Arm64Ty, uint64_t SizeInBytes) {
    return ThunkArgInfo{Arm64Ty,
                        llvm::Type::getIntNTy(M->getContext(), SizeInBytes * 8),
                        ThunkArgTranslation::Bitcast};
  };

  auto pointerIndirection = [this](Type *Arm64Ty) {
    return ThunkArgInfo{Arm64Ty, PtrTy,
                        ThunkArgTranslation::PointerIndirection};
  };

  if (T->isFloatTy()) {
    Out << "f";
    return direct(T);
  }

  if (T->isDoubleTy()) {
    Out << "d";
    return direct(T);
  }

  if (T->isFloatingPointTy()) {
    report_fatal_error(
        "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
  }

  auto &DL = M->getDataLayout();

  // A single-element struct is treated as its element.
  if (auto *StructTy = dyn_cast<StructType>(T))
    if (StructTy->getNumElements() == 1)
      T = StructTy->getElementType(0);

  if (T->isArrayTy()) {
    Type *ElementTy = T->getArrayElementType();
    uint64_t ElementCnt = T->getArrayNumElements();
    uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
    uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
    if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
      Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
      if (Alignment.value() >= 16 && !Ret)
        Out << "a" << Alignment.value();
      if (TotalSizeBytes <= 8) {
        // Arm64 returns small structs of float/double in float registers;
        // X64 uses RAX.
        return bitcast(T, TotalSizeBytes);
      } else {
        // Struct is passed directly on Arm64, but indirectly on X64.
        return pointerIndirection(T);
      }
    } else if (T->isFloatingPointTy()) {
      report_fatal_error("Only 32 and 64 bit floating points are supported for "
                         "ARM64EC thunks");
    }
  }

  // Register-sized integers and pointers pass through as i64.
  if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
    Out << "i8";
    return direct(I64Ty);
  }

  // Everything else is a "memory" type; mangle by size (4 is the implicit
  // default and is omitted).
  unsigned TypeSize = ArgSizeBytes;
  if (TypeSize == 0)
    TypeSize = DL.getTypeSizeInBits(T) / 8;
  Out << "m";
  if (TypeSize != 4)
    Out << TypeSize;
  if (Alignment.value() >= 16 && !Ret)
    Out << "a" << Alignment.value();
  // FIXME: Try to canonicalize Arm64Ty more thoroughly?
  if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
    // Pass directly in an integer register
    return bitcast(T, TypeSize);
  } else {
    // Passed directly on Arm64, but indirectly on X64.
    return pointerIndirection(T);
  }
}
385
// This function builds the "exit thunk", a function which translates
// arguments and return values when calling x64 code from AArch64 code.
Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
                                                     AttributeList Attrs) {
  SmallString<256> ExitThunkName;
  llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
  FunctionType *Arm64Ty, *X64Ty;
  SmallVector<ThunkArgTranslation> ArgTranslations;
  getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
               X64Ty, ArgTranslations);
  // Thunks are deduplicated by mangled name; reuse an existing one if found.
  if (Function *F = M->getFunction(ExitThunkName))
    return F;

  Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
                                 ExitThunkName, M);
  F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
  F->setSection(".wowthk$aa");
  F->setComdat(M->getOrInsertComdat(ExitThunkName));
  // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
  F->addFnAttr("frame-pointer", "all");
  // Only copy sret from the first argument. For C++ instance methods, clang can
  // stick an sret marking on a later argument, but it doesn't actually affect
  // the ABI, so we can omit it. This avoids triggering a verifier assertion.
  if (FT->getNumParams()) {
    auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
    auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
    // Thunk param 0 is the callee (x9), so the sret lands on param 1.
    if (SRet.isValid() && !InReg.isValid())
      F->addParamAttr(1, SRet);
  }
  // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
  // C ABI, but might show up in other cases.
  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
  IRBuilder<> IRB(BB);
  // The actual x64 target is reached via the OS dispatcher loaded here.
  Value *CalleePtr =
      M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
  Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
  auto &DL = M->getDataLayout();
  SmallVector<Value *> Args;

  // Pass the called function in x9.
  auto X64TyOffset = 1;
  Args.push_back(F->arg_begin());

  Type *RetTy = Arm64Ty->getReturnType();
  if (RetTy != X64Ty->getReturnType()) {
    // If the return type is an array or struct, translate it. Values of size
    // 8 or less go into RAX; bigger values go into memory, and we pass a
    // pointer.
    if (DL.getTypeStoreSize(RetTy) > 8) {
      Args.push_back(IRB.CreateAlloca(RetTy));
      X64TyOffset++;
    }
  }

  for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal(
           make_range(F->arg_begin() + 1, F->arg_end()),
           make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()),
           ArgTranslations)) {
    // Translate arguments from AArch64 calling convention to x86 calling
    // convention.
    //
    // For simple types, we don't need to do any translation: they're
    // represented the same way. (Implicit sign extension is not part of
    // either convention.)
    //
    // The big thing we have to worry about is struct types... but
    // fortunately AArch64 clang is pretty friendly here: the cases that need
    // translation are always passed as a struct or array. (If we run into
    // some cases where this doesn't work, we can teach clang to mark it up
    // with an attribute.)
    //
    // The first argument is the called function, stored in x9.
    if (ArgTranslation != ThunkArgTranslation::Direct) {
      // Spill the Arm64 value so it can be reinterpreted (Bitcast) or
      // passed by address (PointerIndirection).
      Value *Mem = IRB.CreateAlloca(Arg.getType());
      IRB.CreateStore(&Arg, Mem);
      if (ArgTranslation == ThunkArgTranslation::Bitcast) {
        Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
        Args.push_back(IRB.CreateLoad(IntTy, Mem));
      } else {
        assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
        Args.push_back(Mem);
      }
    } else {
      Args.push_back(&Arg);
    }
    assert(Args.back()->getType() == X64ArgType);
  }
  // FIXME: Transfer necessary attributes? sret? anything else?

  CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
  Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);

  Value *RetVal = Call;
  if (RetTy != X64Ty->getReturnType()) {
    // If we rewrote the return type earlier, convert the return value to
    // the proper type.
    if (DL.getTypeStoreSize(RetTy) > 8) {
      // Large return: read it back out of the sret temporary (Args[1]).
      RetVal = IRB.CreateLoad(RetTy, Args[1]);
    } else {
      // Small return: reinterpret the RAX bits as the Arm64 return type.
      Value *CastAlloca = IRB.CreateAlloca(RetTy);
      IRB.CreateStore(Call, CastAlloca);
      RetVal = IRB.CreateLoad(RetTy, CastAlloca);
    }
  }

  if (RetTy->isVoidTy())
    IRB.CreateRetVoid();
  else
    IRB.CreateRet(RetVal);
  return F;
}
497
// This function builds the "entry thunk", a function which translates
// arguments and return values when calling AArch64 code from x64 code.
Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
  SmallString<256> EntryThunkName;
  llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
  FunctionType *Arm64Ty, *X64Ty;
  SmallVector<ThunkArgTranslation> ArgTranslations;
  getThunkType(F->getFunctionType(), F->getAttributes(),
               Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,
               ArgTranslations);
  // Thunks are deduplicated by mangled name; reuse an existing one if found.
  if (Function *F = M->getFunction(EntryThunkName))
    return F;

  Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
                                     EntryThunkName, M);
  Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
  Thunk->setSection(".wowthk$aa");
  Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
  // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
  Thunk->addFnAttr("frame-pointer", "all");

  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
  IRBuilder<> IRB(BB);

  Type *RetTy = Arm64Ty->getReturnType();
  Type *X64RetType = X64Ty->getReturnType();

  // The x64 side returns void while Arm64 returns a value when getThunkRetType
  // moved the return into an sret pointer (thunk arg 1) on the x64 side only.
  bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
  unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
  unsigned PassthroughArgSize =
      (F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset;
  assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize));

  // Translate arguments to call.
  SmallVector<Value *> Args;
  for (unsigned i = 0; i != PassthroughArgSize; ++i) {
    Value *Arg = Thunk->getArg(i + ThunkArgOffset);
    Type *ArgTy = Arm64Ty->getParamType(i);
    ThunkArgTranslation ArgTranslation = ArgTranslations[i];
    if (ArgTranslation != ThunkArgTranslation::Direct) {
      // Translate array/struct arguments to the expected type.
      if (ArgTranslation == ThunkArgTranslation::Bitcast) {
        // Reinterpret the x64 integer bits as the Arm64 aggregate type.
        Value *CastAlloca = IRB.CreateAlloca(ArgTy);
        IRB.CreateStore(Arg, CastAlloca);
        Arg = IRB.CreateLoad(ArgTy, CastAlloca);
      } else {
        assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
        // x64 passed a pointer; load the value for Arm64's by-value arg.
        Arg = IRB.CreateLoad(ArgTy, Arg);
      }
    }
    assert(Arg->getType() == ArgTy);
    Args.push_back(Arg);
  }

  if (F->isVarArg()) {
    // The 5th argument to variadic entry thunks is used to model the x64 sp
    // which is passed to the thunk in x4, this can be passed to the callee as
    // the variadic argument start address after skipping over the 32 byte
    // shadow store.

    // The EC thunk CC will assign any argument marked as InReg to x4.
    Thunk->addParamAttr(5, Attribute::InReg);
    Value *Arg = Thunk->getArg(5);
    Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
    Args.push_back(Arg);

    // Pass in a zero variadic argument size (in x5).
    Args.push_back(IRB.getInt64(0));
  }

  // Call the function passed to the thunk.
  Value *Callee = Thunk->getArg(0);
  CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);

  // Propagate a plain (non-inreg) sret from the target onto both the thunk's
  // sret slot (arg 1) and the inner call's first argument.
  auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
  auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
  if (SRetAttr.isValid() && !InRegAttr.isValid()) {
    Thunk->addParamAttr(1, SRetAttr);
    Call->addParamAttr(0, SRetAttr);
  }

  Value *RetVal = Call;
  if (TransformDirectToSRet) {
    // Store the Arm64 return value through the x64 sret pointer.
    IRB.CreateStore(RetVal, Thunk->getArg(1));
  } else if (X64RetType != RetTy) {
    // Reinterpret the Arm64 return bits as the x64 return type.
    Value *CastAlloca = IRB.CreateAlloca(X64RetType);
    IRB.CreateStore(Call, CastAlloca);
    RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
  }

  // Return to the caller. Note that the isel has code to translate this
  // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
  // could emit a tail call here, but that would require a dedicated calling
  // convention, which seems more complicated overall.)
  if (X64RetType->isVoidTy())
    IRB.CreateRetVoid();
  else
    IRB.CreateRet(RetVal);

  return Thunk;
}
599
600 // Builds the "guest exit thunk", a helper to call a function which may or may
601 // not be an exit thunk. (We optimistically assume non-dllimport function
602 // declarations refer to functions defined in AArch64 code; if the linker
603 // can't prove that, we use this routine instead.)
buildGuestExitThunk(Function * F)604 Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
605 llvm::raw_null_ostream NullThunkName;
606 FunctionType *Arm64Ty, *X64Ty;
607 SmallVector<ThunkArgTranslation> ArgTranslations;
608 getThunkType(F->getFunctionType(), F->getAttributes(),
609 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
610 ArgTranslations);
611 auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
612 assert(MangledName && "Can't guest exit to function that's already native");
613 std::string ThunkName = *MangledName;
614 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
615 ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
616 } else {
617 ThunkName.append("$exit_thunk");
618 }
619 Function *GuestExit =
620 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
621 GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
622 GuestExit->setSection(".wowthk$aa");
623 GuestExit->addMetadata(
624 "arm64ec_unmangled_name",
625 *MDNode::get(M->getContext(),
626 MDString::get(M->getContext(), F->getName())));
627 GuestExit->setMetadata(
628 "arm64ec_ecmangled_name",
629 MDNode::get(M->getContext(),
630 MDString::get(M->getContext(), *MangledName)));
631 F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
632 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
633 IRBuilder<> B(BB);
634
635 // Load the global symbol as a pointer to the check function.
636 Value *GuardFn;
637 if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
638 GuardFn = GuardFnCFGlobal;
639 else
640 GuardFn = GuardFnGlobal;
641 LoadInst *GuardCheckLoad = B.CreateLoad(PtrTy, GuardFn);
642
643 // Create new call instruction. The CFGuard check should always be a call,
644 // even if the original CallBase is an Invoke or CallBr instruction.
645 Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
646 CallInst *GuardCheck = B.CreateCall(
647 GuardFnType, GuardCheckLoad, {F, Thunk});
648
649 // Ensure that the first argument is passed in the correct register.
650 GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
651
652 SmallVector<Value *> Args(llvm::make_pointer_range(GuestExit->args()));
653 CallInst *Call = B.CreateCall(Arm64Ty, GuardCheck, Args);
654 Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
655
656 if (Call->getType()->isVoidTy())
657 B.CreateRetVoid();
658 else
659 B.CreateRet(Call);
660
661 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
662 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
663 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
664 GuestExit->addParamAttr(0, SRetAttr);
665 Call->addParamAttr(0, SRetAttr);
666 }
667
668 return GuestExit;
669 }
670
671 Function *
buildPatchableThunk(GlobalAlias * UnmangledAlias,GlobalAlias * MangledAlias)672 AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,
673 GlobalAlias *MangledAlias) {
674 llvm::raw_null_ostream NullThunkName;
675 FunctionType *Arm64Ty, *X64Ty;
676 Function *F = cast<Function>(MangledAlias->getAliasee());
677 SmallVector<ThunkArgTranslation> ArgTranslations;
678 getThunkType(F->getFunctionType(), F->getAttributes(),
679 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
680 ArgTranslations);
681 std::string ThunkName(MangledAlias->getName());
682 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
683 ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");
684 } else {
685 ThunkName.append("$hybpatch_thunk");
686 }
687
688 Function *GuestExit =
689 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
690 GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
691 GuestExit->setSection(".wowthk$aa");
692 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
693 IRBuilder<> B(BB);
694
695 // Load the global symbol as a pointer to the check function.
696 LoadInst *DispatchLoad = B.CreateLoad(PtrTy, DispatchFnGlobal);
697
698 // Create new dispatch call instruction.
699 Function *ExitThunk =
700 buildExitThunk(F->getFunctionType(), F->getAttributes());
701 CallInst *Dispatch =
702 B.CreateCall(DispatchFnType, DispatchLoad,
703 {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()});
704
705 // Ensure that the first arguments are passed in the correct registers.
706 Dispatch->setCallingConv(CallingConv::CFGuard_Check);
707
708 SmallVector<Value *> Args(llvm::make_pointer_range(GuestExit->args()));
709 CallInst *Call = B.CreateCall(Arm64Ty, Dispatch, Args);
710 Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
711
712 if (Call->getType()->isVoidTy())
713 B.CreateRetVoid();
714 else
715 B.CreateRet(Call);
716
717 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
718 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
719 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
720 GuestExit->addParamAttr(0, SRetAttr);
721 Call->addParamAttr(0, SRetAttr);
722 }
723
724 MangledAlias->setAliasee(GuestExit);
725 return GuestExit;
726 }
727
728 // Lower an indirect call with inline code.
lowerCall(CallBase * CB)729 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
730 assert(CB->getModule()->getTargetTriple().isOSWindows() &&
731 "Only applicable for Windows targets");
732
733 IRBuilder<> B(CB);
734 Value *CalledOperand = CB->getCalledOperand();
735
736 // If the indirect call is called within catchpad or cleanuppad,
737 // we need to copy "funclet" bundle of the call.
738 SmallVector<llvm::OperandBundleDef, 1> Bundles;
739 if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
740 Bundles.push_back(OperandBundleDef(*Bundle));
741
742 // Load the global symbol as a pointer to the check function.
743 Value *GuardFn;
744 if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
745 GuardFn = GuardFnCFGlobal;
746 else
747 GuardFn = GuardFnGlobal;
748 LoadInst *GuardCheckLoad = B.CreateLoad(PtrTy, GuardFn);
749
750 // Create new call instruction. The CFGuard check should always be a call,
751 // even if the original CallBase is an Invoke or CallBr instruction.
752 Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
753 CallInst *GuardCheck =
754 B.CreateCall(GuardFnType, GuardCheckLoad, {CalledOperand, Thunk},
755 Bundles);
756
757 // Ensure that the first argument is passed in the correct register.
758 GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
759
760 CB->setCalledOperand(GuardCheck);
761 }
762
// Entry point: rewrite the module so that calls which may cross the
// ARM64EC <-> x64 boundary are routed through the OS runtime helpers, and
// emit the entry/exit/guest-exit thunks plus the symbol map the linker
// consumes. Returns true whenever thunk generation is enabled, since the
// module is mutated unconditionally past that point.
bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
  if (!GenerateThunks)
    return false;

  M = &Mod;

  // Check if this module has the cfguard flag and read its value.
  if (auto *MD =
          mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
    cfguard_module_flag = MD->getZExtValue();

  // Cache commonly used types and the __os_arm64x_* helper globals; these
  // globals are filled in by the OS loader at runtime.
  PtrTy = PointerType::getUnqual(M->getContext());
  I64Ty = Type::getInt64Ty(M->getContext());
  VoidTy = Type::getVoidTy(M->getContext());

  GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
  DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
  GuardFnCFGlobal = M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", PtrTy);
  GuardFnGlobal = M->getOrInsertGlobal("__os_arm64x_check_icall", PtrTy);
  DispatchFnGlobal = M->getOrInsertGlobal("__os_arm64x_dispatch_call", PtrTy);

  // Mangle names of function aliases and add the alias name to
  // arm64ec_unmangled_name metadata to ensure a weak anti-dependency symbol is
  // emitted for the alias as well. Do this early, before handling
  // hybrid_patchable functions, to avoid mangling their aliases.
  for (GlobalAlias &A : Mod.aliases()) {
    auto F = dyn_cast_or_null<Function>(A.getAliaseeObject());
    if (!F)
      continue;
    if (std::optional<std::string> MangledName =
            getArm64ECMangledFunctionName(A.getName().str())) {
      // The unmangled alias name is recorded on the aliasee function so the
      // AsmPrinter can emit the anti-dependency symbol for it.
      F->addMetadata("arm64ec_unmangled_name",
                     *MDNode::get(M->getContext(),
                                  MDString::get(M->getContext(), A.getName())));
      A.setName(MangledName.value());
    }
  }

  // Maps the unmangled alias of a hybrid_patchable function to its mangled
  // alias; PatchableFns remembers the unmangled aliases needing thunks.
  DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
  SetVector<GlobalAlias *> PatchableFns;

  for (Function &F : Mod) {
    // Personality routines are referenced by name in unwind tables, so they
    // must use the mangled symbol too.
    if (F.hasPersonalityFn()) {
      GlobalValue *PersFn =
          cast<GlobalValue>(F.getPersonalityFn()->stripPointerCasts());
      if (PersFn->getValueType() && PersFn->getValueType()->isFunctionTy()) {
        if (std::optional<std::string> MangledName =
                getArm64ECMangledFunctionName(PersFn->getName().str())) {
          PersFn->setName(MangledName.value());
        }
      }
    }

    // Only defined, externally visible hybrid_patchable functions that have
    // not already been renamed are rewritten below.
    if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() ||
        F.hasLocalLinkage() ||
        F.getName().ends_with(HybridPatchableTargetSuffix))
      continue;

    // Rename hybrid patchable functions and change callers to use a global
    // alias instead.
    if (std::optional<std::string> MangledName =
            getArm64ECMangledFunctionName(F.getName().str())) {
      std::string OrigName(F.getName());
      F.setName(MangledName.value() + HybridPatchableTargetSuffix);

      // The unmangled symbol is a weak alias to an undefined symbol with the
      // "EXP+" prefix. This undefined symbol is resolved by the linker by
      // creating an x86 thunk that jumps back to the actual EC target. Since we
      // can't represent that in IR, we create an alias to the target instead.
      // The "EXP+" symbol is set as metadata, which is then used by
      // emitGlobalAlias to emit the right alias.
      auto *A =
          GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F);
      auto *AM = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage,
                                     MangledName.value(), &F);
      // Uses from other aliases keep the mangled alias; every remaining use
      // is redirected to the unmangled one. Order matters: the second RAUW
      // must not touch the uses already rewritten to AM.
      F.replaceUsesWithIf(AM,
                          [](Use &U) { return isa<GlobalAlias>(U.getUser()); });
      F.replaceAllUsesWith(A);
      F.setMetadata("arm64ec_exp_name",
                    MDNode::get(M->getContext(),
                                MDString::get(M->getContext(),
                                              "EXP+" + MangledName.value())));
      // The RAUW calls above also rewrote the aliasees; point both aliases
      // back at the renamed function.
      A->setAliasee(&F);
      AM->setAliasee(&F);

      // The export surface moves to the unmangled alias.
      if (F.hasDLLExportStorageClass()) {
        A->setDLLStorageClass(GlobalValue::DLLExportStorageClass);
        F.setDLLStorageClass(GlobalValue::DefaultStorageClass);
      }

      FnsMap[A] = AM;
      PatchableFns.insert(A);
    }
  }

  // Lower calls in every normal (non-thunk) function body; this also collects
  // the direct callees that need exit thunks.
  SetVector<GlobalValue *> DirectCalledFns;
  for (Function &F : Mod)
    if (!F.isDeclaration() &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
      processFunction(F, DirectCalledFns, FnsMap);

  // One (Src, Dst, Kind) record per thunk, later serialized into the
  // llvm.arm64ec.symbolmap global for the linker.
  struct ThunkInfo {
    Constant *Src;
    Constant *Dst;
    Arm64ECThunkType Kind;
  };
  SmallVector<ThunkInfo> ThunkMapping;
  // Entry thunks: every defined function whose address may escape gets one,
  // and is placed in a comdat so duplicate thunks deduplicate at link time.
  for (Function &F : Mod) {
    if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
      if (!F.hasComdat())
        F.setComdat(Mod.getOrInsertComdat(F.getName()));
      ThunkMapping.push_back(
          {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
    }
  }
  // Exit thunks for direct callees; non-alias, non-dllimport callees also get
  // a guest-exit thunk so x64 code can call back into them.
  for (GlobalValue *O : DirectCalledFns) {
    auto GA = dyn_cast<GlobalAlias>(O);
    auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O);
    ThunkMapping.push_back(
        {O, buildExitThunk(F->getFunctionType(), F->getAttributes()),
         Arm64ECThunkType::Exit});
    if (!GA && !F->hasDLLImportStorageClass())
      ThunkMapping.push_back(
          {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
  }
  for (GlobalAlias *A : PatchableFns) {
    Function *Thunk = buildPatchableThunk(A, FnsMap[A]);
    ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit});
  }

  if (!ThunkMapping.empty()) {
    // Serialize the records as an array of {ptr, ptr, i32} structs; the
    // thunk kind is stored widened to 32 bits.
    SmallVector<Constant *> ThunkMappingArrayElems;
    for (ThunkInfo &Thunk : ThunkMapping) {
      ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
          {Thunk.Src, Thunk.Dst,
           ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))}));
    }
    Constant *ThunkMappingArray = ConstantArray::get(
        llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
                             ThunkMappingArrayElems.size()),
        ThunkMappingArrayElems);
    new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
                       GlobalValue::ExternalLinkage, ThunkMappingArray,
                       "llvm.arm64ec.symbolmap");
  }

  return true;
}
914
processFunction(Function & F,SetVector<GlobalValue * > & DirectCalledFns,DenseMap<GlobalAlias *,GlobalAlias * > & FnsMap)915 bool AArch64Arm64ECCallLowering::processFunction(
916 Function &F, SetVector<GlobalValue *> &DirectCalledFns,
917 DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
918 SmallVector<CallBase *, 8> IndirectCalls;
919
920 // For ARM64EC targets, a function definition's name is mangled differently
921 // from the normal symbol. We currently have no representation of this sort
922 // of symbol in IR, so we change the name to the mangled name, then store
923 // the unmangled name as metadata. Later passes that need the unmangled
924 // name (emitting the definition) can grab it from the metadata.
925 //
926 // FIXME: Handle functions with weak linkage?
927 if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
928 if (std::optional<std::string> MangledName =
929 getArm64ECMangledFunctionName(F.getName().str())) {
930 F.addMetadata("arm64ec_unmangled_name",
931 *MDNode::get(M->getContext(),
932 MDString::get(M->getContext(), F.getName())));
933 if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
934 Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
935 SmallVector<GlobalObject *> ComdatUsers =
936 to_vector(F.getComdat()->getUsers());
937 for (GlobalObject *User : ComdatUsers)
938 User->setComdat(MangledComdat);
939 }
940 F.setName(MangledName.value());
941 }
942 }
943
944 // Iterate over the instructions to find all indirect call/invoke/callbr
945 // instructions. Make a separate list of pointers to indirect
946 // call/invoke/callbr instructions because the original instructions will be
947 // deleted as the checks are added.
948 for (BasicBlock &BB : F) {
949 for (Instruction &I : BB) {
950 auto *CB = dyn_cast<CallBase>(&I);
951 if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
952 CB->isInlineAsm())
953 continue;
954
955 // We need to instrument any call that isn't directly calling an
956 // ARM64 function.
957 //
958 // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
959 // unprototyped functions in C)
960 if (Function *F = CB->getCalledFunction()) {
961 if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
962 F->isIntrinsic() || !F->isDeclaration())
963 continue;
964
965 DirectCalledFns.insert(F);
966 continue;
967 }
968
969 // Use mangled global alias for direct calls to patchable functions.
970 if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) {
971 auto I = FnsMap.find(A);
972 if (I != FnsMap.end()) {
973 CB->setCalledOperand(I->second);
974 DirectCalledFns.insert(I->first);
975 continue;
976 }
977 }
978
979 IndirectCalls.push_back(CB);
980 ++Arm64ECCallsLowered;
981 }
982 }
983
984 if (IndirectCalls.empty())
985 return false;
986
987 for (CallBase *CB : IndirectCalls)
988 lowerCall(CB);
989
990 return true;
991 }
992
// Pass identification token used by the legacy pass manager.
char AArch64Arm64ECCallLowering::ID = 0;
// Register the pass under the CLI/debug name "Arm64ECCallLowering";
// it neither only-looks-at-CFG nor is it an analysis pass.
INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
                "AArch64Arm64ECCallLowering", false, false)
996
createAArch64Arm64ECCallLoweringPass()997 ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
998 return new AArch64Arm64ECCallLowering;
999 }
1000