xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp (revision 9c77fb6aaa366cbabc80ee1b834bcfe4df135491)
1 //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file contains the IR transform to lower external or indirect calls for
11 /// the ARM64EC calling convention. Such calls must go through the runtime, so
12 /// we can translate the calling convention for calls into the emulator.
13 ///
14 /// This subsumes Control Flow Guard handling.
15 ///
16 //===----------------------------------------------------------------------===//
17 
18 #include "AArch64.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/IR/CallingConv.h"
24 #include "llvm/IR/GlobalAlias.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Instruction.h"
27 #include "llvm/IR/Mangler.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/Object/COFF.h"
30 #include "llvm/Pass.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/TargetParser/Triple.h"
33 
using namespace llvm;
using namespace llvm::COFF;

// Operand bundles in this pass always carry plain Value pointers.
using OperandBundleDef = OperandBundleDefT<Value *>;

#define DEBUG_TYPE "arm64eccalllowering"

STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");

// NOTE(review): consumed later in the pass (outside this chunk); the name
// suggests it gates rewriting direct calls into indirect dispatch -- confirm
// at the use site.
static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
                                           cl::Hidden, cl::init(true));
// When false, runOnModule bails out immediately, disabling the whole pass.
static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
                                    cl::init(true));
47 
namespace {

// How a single argument or return value must be rewritten when it crosses
// the AArch64 <-> x64 boundary.
enum ThunkArgTranslation : uint8_t {
  // Identical representation on both sides; passed through unchanged.
  Direct,
  // Same size but different register class; the bits are reinterpreted by
  // round-tripping through memory.
  Bitcast,
  // By-value on AArch64 but by-pointer on x64; spilled to a stack slot
  // whose address is passed instead.
  PointerIndirection,
};

// Result of canonicalizeThunkType: the IR type used on each side of the
// boundary and the translation required between them.
struct ThunkArgInfo {
  Type *Arm64Ty;
  Type *X64Ty;
  ThunkArgTranslation Translation;
};

// Module pass implementing the IR-level portion of ARM64EC call lowering:
// building entry/exit/guest-exit/patchable thunks and rewriting indirect
// calls to dispatch through the OS check helpers.
class AArch64Arm64ECCallLowering : public ModulePass {
public:
  static char ID;
  AArch64Arm64ECCallLowering() : ModulePass(ID) {}

  // Build (or reuse) the exit thunk for calling x64 code with this
  // prototype from AArch64 code.
  Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
  // Build (or reuse) the entry thunk for calling F from x64 code.
  Function *buildEntryThunk(Function *F);
  // Rewrite an indirect call to dispatch through the check helper.
  void lowerCall(CallBase *CB);
  // Build the weak thunk used when a direct callee might really be x64.
  Function *buildGuestExitThunk(Function *F);
  // Build the "$hybpatch_thunk" for a hybrid_patchable alias pair.
  Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,
                                GlobalAlias *MangledAlias);
  // Per-function worker for runOnModule (body not in this chunk).
  bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,
                       DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
  bool runOnModule(Module &M) override;

private:
  // Value of the module's "cfguard" flag; compared against 2 before
  // selecting the CFG-enforcing check function.
  int cfguard_module_flag = 0;
  // Prototypes of the OS check/dispatch helpers.
  FunctionType *GuardFnType = nullptr;
  FunctionType *DispatchFnType = nullptr;
  // Globals holding pointers to the OS helpers.
  Constant *GuardFnCFGlobal = nullptr;
  Constant *GuardFnGlobal = nullptr;
  Constant *DispatchFnGlobal = nullptr;
  Module *M = nullptr;

  // Frequently used types, cached per module in runOnModule.
  Type *PtrTy;
  Type *I64Ty;
  Type *VoidTy;

  // Compute the thunk's mangled name (streamed into Out) and its AArch64
  // and x64 function types for the given prototype.
  void getThunkType(FunctionType *FT, AttributeList AttrList,
                    Arm64ECThunkType TT, raw_ostream &Out,
                    FunctionType *&Arm64Ty, FunctionType *&X64Ty,
                    SmallVector<ThunkArgTranslation> &ArgTranslations);
  void getThunkRetType(FunctionType *FT, AttributeList AttrList,
                       raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
                       SmallVectorImpl<Type *> &Arm64ArgTypes,
                       SmallVectorImpl<Type *> &X64ArgTypes,
                       SmallVector<ThunkArgTranslation> &ArgTranslations,
                       bool &HasSretPtr);
  void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
                        Arm64ECThunkType TT, raw_ostream &Out,
                        SmallVectorImpl<Type *> &Arm64ArgTypes,
                        SmallVectorImpl<Type *> &X64ArgTypes,
                        SmallVectorImpl<ThunkArgTranslation> &ArgTranslations,
                        bool HasSretPtr);
  ThunkArgInfo canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
                                     uint64_t ArgSizeBytes, raw_ostream &Out);
};

} // end anonymous namespace
111 
112 void AArch64Arm64ECCallLowering::getThunkType(
113     FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
114     raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty,
115     SmallVector<ThunkArgTranslation> &ArgTranslations) {
116   Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
117                                         : "$iexit_thunk$cdecl$");
118 
119   Type *Arm64RetTy;
120   Type *X64RetTy;
121 
122   SmallVector<Type *> Arm64ArgTypes;
123   SmallVector<Type *> X64ArgTypes;
124 
125   // The first argument to a thunk is the called function, stored in x9.
126   // For exit thunks, we pass the called function down to the emulator;
127   // for entry/guest exit thunks, we just call the Arm64 function directly.
128   if (TT == Arm64ECThunkType::Exit)
129     Arm64ArgTypes.push_back(PtrTy);
130   X64ArgTypes.push_back(PtrTy);
131 
132   bool HasSretPtr = false;
133   getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
134                   X64ArgTypes, ArgTranslations, HasSretPtr);
135 
136   getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
137                    ArgTranslations, HasSretPtr);
138 
139   Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
140 
141   X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
142 }
143 
// Mangle the thunk's parameter list into Out and append the corresponding
// AArch64-side and x64-side IR parameter types (plus the per-argument
// translation strategy) to the output vectors. HasSretPtr is set by
// getThunkRetType when the first parameter was already consumed as an sret
// pointer and must not be mangled again here.
void AArch64Arm64ECCallLowering::getThunkArgTypes(
    FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
    raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
    SmallVectorImpl<Type *> &X64ArgTypes,
    SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) {

  Out << "$";
  if (FT->isVarArg()) {
    // We treat the variadic function's thunk as a normal function
    // with the following type on the ARM side:
    //   rettype exitthunk(
    //     ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
    //
    // that can cover all types of variadic function.
    // x9 is similar to a normal exit thunk: it stores the called function.
    // x0-x3 are the arguments stored in registers.
    // x4 is the address of the arguments on the stack.
    // x5 is the size of the arguments on the stack.
    //
    // On the x64 side, it's the same except that x5 isn't set.
    //
    // If both the ARM and X64 sides are sret, there are only three
    // arguments in registers.
    //
    // If the X64 side is sret, but the ARM side isn't, we pass an extra value
    // to/from the X64 side, and let SelectionDAG transform it into a memory
    // location.
    Out << "varargs";

    // x0-x3 (skip x0 when it already carries the sret pointer)
    for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
      Arm64ArgTypes.push_back(I64Ty);
      X64ArgTypes.push_back(I64Ty);
      ArgTranslations.push_back(ThunkArgTranslation::Direct);
    }

    // x4
    Arm64ArgTypes.push_back(PtrTy);
    X64ArgTypes.push_back(PtrTy);
    ArgTranslations.push_back(ThunkArgTranslation::Direct);
    // x5
    Arm64ArgTypes.push_back(I64Ty);
    if (TT != Arm64ECThunkType::Entry) {
      // FIXME: x5 isn't actually used by the x64 side; revisit once we
      // have proper isel for varargs
      X64ArgTypes.push_back(I64Ty);
      ArgTranslations.push_back(ThunkArgTranslation::Direct);
    }
    return;
  }

  unsigned I = 0;
  if (HasSretPtr)
    I++;

  // No (non-sret) parameters at all: mangle as "v" (void).
  if (I == FT->getNumParams()) {
    Out << "v";
    return;
  }

  for (unsigned E = FT->getNumParams(); I != E; ++I) {
#if 0
    // FIXME: Need more information about argument size; see
    // https://reviews.llvm.org/D132926
    uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
    Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
#else
    uint64_t ArgSizeBytes = 0;
    Align ParamAlign = Align();
#endif
    auto [Arm64Ty, X64Ty, ArgTranslation] =
        canonicalizeThunkType(FT->getParamType(I), ParamAlign,
                              /*Ret*/ false, ArgSizeBytes, Out);
    Arm64ArgTypes.push_back(Arm64Ty);
    X64ArgTypes.push_back(X64Ty);
    ArgTranslations.push_back(ArgTranslation);
  }
}
222 
// Mangle the thunk's return type into Out and compute the AArch64-side and
// x64-side IR return types. A return that is indirect on one side may be
// turned into an sret pointer argument instead; HasSretPtr reports whether
// the function's first parameter was consumed as that sret pointer.
void AArch64Arm64ECCallLowering::getThunkRetType(
    FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
    Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
    SmallVectorImpl<Type *> &X64ArgTypes,
    SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) {
  Type *T = FT->getReturnType();
#if 0
  // FIXME: Need more information about argument size; see
  // https://reviews.llvm.org/D132926
  uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
#else
  int64_t ArgSizeBytes = 0;
#endif
  if (T->isVoidTy()) {
    // A void return may still be a struct return in disguise, via an sret
    // pointer on the first (or second) parameter.
    if (FT->getNumParams()) {
      Attribute SRetAttr0 = AttrList.getParamAttr(0, Attribute::StructRet);
      Attribute InRegAttr0 = AttrList.getParamAttr(0, Attribute::InReg);
      Attribute SRetAttr1, InRegAttr1;
      if (FT->getNumParams() > 1) {
        // Also check the second parameter (for class methods, the first
        // parameter is "this", and the second parameter is the sret pointer.)
        // It doesn't matter which one is sret.
        SRetAttr1 = AttrList.getParamAttr(1, Attribute::StructRet);
        InRegAttr1 = AttrList.getParamAttr(1, Attribute::InReg);
      }
      if ((SRetAttr0.isValid() && InRegAttr0.isValid()) ||
          (SRetAttr1.isValid() && InRegAttr1.isValid())) {
        // sret+inreg indicates a call that returns a C++ class value. This is
        // actually equivalent to just passing and returning a void* pointer
        // as the first or second argument. Translate it that way, instead of
        // trying to model "inreg" in the thunk's calling convention; this
        // simplifies the rest of the code, and matches MSVC mangling.
        Out << "i8";
        Arm64RetTy = I64Ty;
        X64RetTy = I64Ty;
        return;
      }
      if (SRetAttr0.isValid()) {
        // FIXME: Sanity-check the sret type; if it's an integer or pointer,
        // we'll get screwy mangling/codegen.
        // FIXME: For large struct types, mangle as an integer argument and
        // integer return, so we can reuse more thunks, instead of "m" syntax.
        // (MSVC mangles this case as an integer return with no argument, but
        // that's a miscompile.)
        Type *SRetType = SRetAttr0.getValueAsType();
        Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
        // Called only for its mangling side effect on Out; the sret pointer
        // itself passes through directly on both sides.
        canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
                              Out);
        Arm64RetTy = VoidTy;
        X64RetTy = VoidTy;
        Arm64ArgTypes.push_back(FT->getParamType(0));
        X64ArgTypes.push_back(FT->getParamType(0));
        ArgTranslations.push_back(ThunkArgTranslation::Direct);
        HasSretPtr = true;
        return;
      }
    }

    // Genuinely void on both sides.
    Out << "v";
    Arm64RetTy = VoidTy;
    X64RetTy = VoidTy;
    return;
  }

  auto info =
      canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out);
  Arm64RetTy = info.Arm64Ty;
  X64RetTy = info.X64Ty;
  if (X64RetTy->isPointerTy()) {
    // If the X64 type is canonicalized to a pointer, that means it's
    // passed/returned indirectly. For a return value, that means it's an
    // sret pointer.
    X64ArgTypes.push_back(X64RetTy);
    X64RetTy = VoidTy;
  }
}
299 
// Determine how a single type crosses the AArch64 <-> x64 boundary: the IR
// type used on each side plus the translation required between them. As a
// side effect, appends the type's mangling to Out (used to build thunk
// names). Ret is true when handling a return value (alignment suffixes are
// only emitted for arguments); ArgSizeBytes optionally overrides the
// in-memory size (currently always 0 -- see the FIXMEs at the call sites).
ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
    Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,
    raw_ostream &Out) {

  // Helpers constructing the three possible translation results.
  auto direct = [](Type *T) {
    return ThunkArgInfo{T, T, ThunkArgTranslation::Direct};
  };

  auto bitcast = [this](Type *Arm64Ty, uint64_t SizeInBytes) {
    return ThunkArgInfo{Arm64Ty,
                        llvm::Type::getIntNTy(M->getContext(), SizeInBytes * 8),
                        ThunkArgTranslation::Bitcast};
  };

  auto pointerIndirection = [this](Type *Arm64Ty) {
    return ThunkArgInfo{Arm64Ty, PtrTy,
                        ThunkArgTranslation::PointerIndirection};
  };

  if (T->isFloatTy()) {
    Out << "f";
    return direct(T);
  }

  if (T->isDoubleTy()) {
    Out << "d";
    return direct(T);
  }

  if (T->isFloatingPointTy()) {
    report_fatal_error(
        "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
  }

  auto &DL = M->getDataLayout();

  // A single-element struct is canonicalized to its element type.
  if (auto *StructTy = dyn_cast<StructType>(T))
    if (StructTy->getNumElements() == 1)
      T = StructTy->getElementType(0);

  if (T->isArrayTy()) {
    Type *ElementTy = T->getArrayElementType();
    uint64_t ElementCnt = T->getArrayNumElements();
    uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
    uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
    if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
      // Homogeneous float/double arrays mangle as "F<n>"/"D<n>".
      Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
      if (Alignment.value() >= 16 && !Ret)
        Out << "a" << Alignment.value();
      if (TotalSizeBytes <= 8) {
        // Arm64 returns small structs of float/double in float registers;
        // X64 uses RAX.
        return bitcast(T, TotalSizeBytes);
      } else {
        // Struct is passed directly on Arm64, but indirectly on X64.
        return pointerIndirection(T);
      }
    } else if (T->isFloatingPointTy()) {
      report_fatal_error("Only 32 and 64 bit floating points are supported for "
                         "ARM64EC thunks");
    }
  }

  if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
    // Small integers and pointers are widened to i64, mangled "i8" (8 bytes).
    Out << "i8";
    return direct(I64Ty);
  }

  // Everything else mangles as memory: "m<size>" (size omitted when 4).
  unsigned TypeSize = ArgSizeBytes;
  if (TypeSize == 0)
    TypeSize = DL.getTypeSizeInBits(T) / 8;
  Out << "m";
  if (TypeSize != 4)
    Out << TypeSize;
  if (Alignment.value() >= 16 && !Ret)
    Out << "a" << Alignment.value();
  // FIXME: Try to canonicalize Arm64Ty more thoroughly?
  if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
    // Pass directly in an integer register
    return bitcast(T, TypeSize);
  } else {
    // Passed directly on Arm64, but indirectly on X64.
    return pointerIndirection(T);
  }
}
385 
// This function builds the "exit thunk", a function which translates
// arguments and return values when calling x64 code from AArch64 code.
// Thunks are deduplicated by mangled name and emitted into .wowthk$aa.
Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
                                                     AttributeList Attrs) {
  SmallString<256> ExitThunkName;
  llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
  FunctionType *Arm64Ty, *X64Ty;
  SmallVector<ThunkArgTranslation> ArgTranslations;
  getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
               X64Ty, ArgTranslations);
  // Reuse an existing thunk with the same mangled type, if any.
  if (Function *F = M->getFunction(ExitThunkName))
    return F;

  Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
                                 ExitThunkName, M);
  F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
  F->setSection(".wowthk$aa");
  F->setComdat(M->getOrInsertComdat(ExitThunkName));
  // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
  F->addFnAttr("frame-pointer", "all");
  // Only copy sret from the first argument. For C++ instance methods, clang can
  // stick an sret marking on a later argument, but it doesn't actually affect
  // the ABI, so we can omit it. This avoids triggering a verifier assertion.
  if (FT->getNumParams()) {
    auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
    auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
    if (SRet.isValid() && !InReg.isValid())
      // Index 1: argument 0 of the thunk is the callee passed in x9.
      F->addParamAttr(1, SRet);
  }
  // FIXME: Copy anything other than sret?  Shouldn't be necessary for normal
  // C ABI, but might show up in other cases.
  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
  IRBuilder<> IRB(BB);
  // The actual x64 dispatch target is read from an OS-provided global.
  Value *CalleePtr =
      M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
  Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
  auto &DL = M->getDataLayout();
  SmallVector<Value *> Args;

  // Pass the called function in x9.
  auto X64TyOffset = 1;
  Args.push_back(F->arg_begin());

  Type *RetTy = Arm64Ty->getReturnType();
  if (RetTy != X64Ty->getReturnType()) {
    // If the return type is an array or struct, translate it. Values of size
    // 8 or less go into RAX; bigger values go into memory, and we pass a
    // pointer.
    if (DL.getTypeStoreSize(RetTy) > 8) {
      Args.push_back(IRB.CreateAlloca(RetTy));
      // The x64 signature gained an sret pointer, so its parameter list is
      // shifted one further relative to the Arm64 side.
      X64TyOffset++;
    }
  }

  for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal(
           make_range(F->arg_begin() + 1, F->arg_end()),
           make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()),
           ArgTranslations)) {
    // Translate arguments from AArch64 calling convention to x86 calling
    // convention.
    //
    // For simple types, we don't need to do any translation: they're
    // represented the same way. (Implicit sign extension is not part of
    // either convention.)
    //
    // The big thing we have to worry about is struct types... but
    // fortunately AArch64 clang is pretty friendly here: the cases that need
    // translation are always passed as a struct or array. (If we run into
    // some cases where this doesn't work, we can teach clang to mark it up
    // with an attribute.)
    //
    // The first argument is the called function, stored in x9.
    if (ArgTranslation != ThunkArgTranslation::Direct) {
      // Spill the value so it can be either reinterpreted (Bitcast) or
      // passed by address (PointerIndirection).
      Value *Mem = IRB.CreateAlloca(Arg.getType());
      IRB.CreateStore(&Arg, Mem);
      if (ArgTranslation == ThunkArgTranslation::Bitcast) {
        Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
        Args.push_back(IRB.CreateLoad(IntTy, Mem));
      } else {
        assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
        Args.push_back(Mem);
      }
    } else {
      Args.push_back(&Arg);
    }
    assert(Args.back()->getType() == X64ArgType);
  }
  // FIXME: Transfer necessary attributes? sret? anything else?

  CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
  Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);

  Value *RetVal = Call;
  if (RetTy != X64Ty->getReturnType()) {
    // If we rewrote the return type earlier, convert the return value to
    // the proper type.
    if (DL.getTypeStoreSize(RetTy) > 8) {
      // Large return: read it back out of the sret slot we allocated above.
      RetVal = IRB.CreateLoad(RetTy, Args[1]);
    } else {
      // Small return: reinterpret the integer in RAX as the Arm64 type.
      Value *CastAlloca = IRB.CreateAlloca(RetTy);
      IRB.CreateStore(Call, CastAlloca);
      RetVal = IRB.CreateLoad(RetTy, CastAlloca);
    }
  }

  if (RetTy->isVoidTy())
    IRB.CreateRetVoid();
  else
    IRB.CreateRet(RetVal);
  return F;
}
497 
// This function builds the "entry thunk", a function which translates
// arguments and return values when calling AArch64 code from x64 code.
Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
  SmallString<256> EntryThunkName;
  llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
  FunctionType *Arm64Ty, *X64Ty;
  SmallVector<ThunkArgTranslation> ArgTranslations;
  getThunkType(F->getFunctionType(), F->getAttributes(),
               Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,
               ArgTranslations);
  // Reuse an existing thunk with the same mangled type, if any.
  if (Function *F = M->getFunction(EntryThunkName))
    return F;

  Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
                                     EntryThunkName, M);
  Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
  Thunk->setSection(".wowthk$aa");
  Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
  // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
  Thunk->addFnAttr("frame-pointer", "all");

  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
  IRBuilder<> IRB(BB);

  Type *RetTy = Arm64Ty->getReturnType();
  Type *X64RetType = X64Ty->getReturnType();

  // True when the x64 side returns through an sret pointer while the Arm64
  // side returns directly in registers.
  bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
  // Skip the callee (arg 0) and, when present, the x64 sret pointer (arg 1).
  unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
  unsigned PassthroughArgSize =
      (F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset;
  assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize));

  // Translate arguments to call.
  SmallVector<Value *> Args;
  for (unsigned i = 0; i != PassthroughArgSize; ++i) {
    Value *Arg = Thunk->getArg(i + ThunkArgOffset);
    Type *ArgTy = Arm64Ty->getParamType(i);
    ThunkArgTranslation ArgTranslation = ArgTranslations[i];
    if (ArgTranslation != ThunkArgTranslation::Direct) {
      // Translate array/struct arguments to the expected type.
      if (ArgTranslation == ThunkArgTranslation::Bitcast) {
        // Same size, different register class: round-trip through memory.
        Value *CastAlloca = IRB.CreateAlloca(ArgTy);
        IRB.CreateStore(Arg, CastAlloca);
        Arg = IRB.CreateLoad(ArgTy, CastAlloca);
      } else {
        assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
        // The x64 caller passed a pointer; load the by-value argument.
        Arg = IRB.CreateLoad(ArgTy, Arg);
      }
    }
    assert(Arg->getType() == ArgTy);
    Args.push_back(Arg);
  }

  if (F->isVarArg()) {
    // The 5th argument to variadic entry thunks is used to model the x64 sp
    // which is passed to the thunk in x4, this can be passed to the callee as
    // the variadic argument start address after skipping over the 32 byte
    // shadow store.

    // The EC thunk CC will assign any argument marked as InReg to x4.
    Thunk->addParamAttr(5, Attribute::InReg);
    Value *Arg = Thunk->getArg(5);
    Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
    Args.push_back(Arg);

    // Pass in a zero variadic argument size (in x5).
    Args.push_back(IRB.getInt64(0));
  }

  // Call the function passed to the thunk.
  Value *Callee = Thunk->getArg(0);
  CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);

  // Propagate a plain sret (without inreg) so the thunk and the call agree
  // on the struct-return ABI; thunk index 1 corresponds to call index 0
  // because the thunk's argument 0 is the callee.
  auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
  auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
  if (SRetAttr.isValid() && !InRegAttr.isValid()) {
    Thunk->addParamAttr(1, SRetAttr);
    Call->addParamAttr(0, SRetAttr);
  }

  Value *RetVal = Call;
  if (TransformDirectToSRet) {
    // Store the direct Arm64 return value through the x64 sret pointer.
    IRB.CreateStore(RetVal, Thunk->getArg(1));
  } else if (X64RetType != RetTy) {
    // Bitcast the return value through memory to the x64 representation.
    Value *CastAlloca = IRB.CreateAlloca(X64RetType);
    IRB.CreateStore(Call, CastAlloca);
    RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
  }

  // Return to the caller.  Note that the isel has code to translate this
  // "ret" to a tail call to __os_arm64x_dispatch_ret.  (Alternatively, we
  // could emit a tail call here, but that would require a dedicated calling
  // convention, which seems more complicated overall.)
  if (X64RetType->isVoidTy())
    IRB.CreateRetVoid();
  else
    IRB.CreateRet(RetVal);

  return Thunk;
}
599 
600 // Builds the "guest exit thunk", a helper to call a function which may or may
601 // not be an exit thunk. (We optimistically assume non-dllimport function
602 // declarations refer to functions defined in AArch64 code; if the linker
603 // can't prove that, we use this routine instead.)
604 Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
605   llvm::raw_null_ostream NullThunkName;
606   FunctionType *Arm64Ty, *X64Ty;
607   SmallVector<ThunkArgTranslation> ArgTranslations;
608   getThunkType(F->getFunctionType(), F->getAttributes(),
609                Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
610                ArgTranslations);
611   auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
612   assert(MangledName && "Can't guest exit to function that's already native");
613   std::string ThunkName = *MangledName;
614   if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
615     ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
616   } else {
617     ThunkName.append("$exit_thunk");
618   }
619   Function *GuestExit =
620       Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
621   GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
622   GuestExit->setSection(".wowthk$aa");
623   GuestExit->addMetadata(
624       "arm64ec_unmangled_name",
625       *MDNode::get(M->getContext(),
626                    MDString::get(M->getContext(), F->getName())));
627   GuestExit->setMetadata(
628       "arm64ec_ecmangled_name",
629       MDNode::get(M->getContext(),
630                   MDString::get(M->getContext(), *MangledName)));
631   F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
632   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
633   IRBuilder<> B(BB);
634 
635   // Load the global symbol as a pointer to the check function.
636   Value *GuardFn;
637   if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
638     GuardFn = GuardFnCFGlobal;
639   else
640     GuardFn = GuardFnGlobal;
641   LoadInst *GuardCheckLoad = B.CreateLoad(PtrTy, GuardFn);
642 
643   // Create new call instruction. The CFGuard check should always be a call,
644   // even if the original CallBase is an Invoke or CallBr instruction.
645   Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
646   CallInst *GuardCheck = B.CreateCall(
647       GuardFnType, GuardCheckLoad, {F, Thunk});
648 
649   // Ensure that the first argument is passed in the correct register.
650   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
651 
652   SmallVector<Value *> Args(llvm::make_pointer_range(GuestExit->args()));
653   CallInst *Call = B.CreateCall(Arm64Ty, GuardCheck, Args);
654   Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
655 
656   if (Call->getType()->isVoidTy())
657     B.CreateRetVoid();
658   else
659     B.CreateRet(Call);
660 
661   auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
662   auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
663   if (SRetAttr.isValid() && !InRegAttr.isValid()) {
664     GuestExit->addParamAttr(0, SRetAttr);
665     Call->addParamAttr(0, SRetAttr);
666   }
667 
668   return GuestExit;
669 }
670 
671 Function *
672 AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,
673                                                 GlobalAlias *MangledAlias) {
674   llvm::raw_null_ostream NullThunkName;
675   FunctionType *Arm64Ty, *X64Ty;
676   Function *F = cast<Function>(MangledAlias->getAliasee());
677   SmallVector<ThunkArgTranslation> ArgTranslations;
678   getThunkType(F->getFunctionType(), F->getAttributes(),
679                Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
680                ArgTranslations);
681   std::string ThunkName(MangledAlias->getName());
682   if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
683     ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");
684   } else {
685     ThunkName.append("$hybpatch_thunk");
686   }
687 
688   Function *GuestExit =
689       Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
690   GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
691   GuestExit->setSection(".wowthk$aa");
692   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
693   IRBuilder<> B(BB);
694 
695   // Load the global symbol as a pointer to the check function.
696   LoadInst *DispatchLoad = B.CreateLoad(PtrTy, DispatchFnGlobal);
697 
698   // Create new dispatch call instruction.
699   Function *ExitThunk =
700       buildExitThunk(F->getFunctionType(), F->getAttributes());
701   CallInst *Dispatch =
702       B.CreateCall(DispatchFnType, DispatchLoad,
703                    {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()});
704 
705   // Ensure that the first arguments are passed in the correct registers.
706   Dispatch->setCallingConv(CallingConv::CFGuard_Check);
707 
708   SmallVector<Value *> Args(llvm::make_pointer_range(GuestExit->args()));
709   CallInst *Call = B.CreateCall(Arm64Ty, Dispatch, Args);
710   Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
711 
712   if (Call->getType()->isVoidTy())
713     B.CreateRetVoid();
714   else
715     B.CreateRet(Call);
716 
717   auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
718   auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
719   if (SRetAttr.isValid() && !InRegAttr.isValid()) {
720     GuestExit->addParamAttr(0, SRetAttr);
721     Call->addParamAttr(0, SRetAttr);
722   }
723 
724   MangledAlias->setAliasee(GuestExit);
725   return GuestExit;
726 }
727 
728 // Lower an indirect call with inline code.
729 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
730   assert(CB->getModule()->getTargetTriple().isOSWindows() &&
731          "Only applicable for Windows targets");
732 
733   IRBuilder<> B(CB);
734   Value *CalledOperand = CB->getCalledOperand();
735 
736   // If the indirect call is called within catchpad or cleanuppad,
737   // we need to copy "funclet" bundle of the call.
738   SmallVector<llvm::OperandBundleDef, 1> Bundles;
739   if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
740     Bundles.push_back(OperandBundleDef(*Bundle));
741 
742   // Load the global symbol as a pointer to the check function.
743   Value *GuardFn;
744   if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
745     GuardFn = GuardFnCFGlobal;
746   else
747     GuardFn = GuardFnGlobal;
748   LoadInst *GuardCheckLoad = B.CreateLoad(PtrTy, GuardFn);
749 
750   // Create new call instruction. The CFGuard check should always be a call,
751   // even if the original CallBase is an Invoke or CallBr instruction.
752   Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
753   CallInst *GuardCheck =
754       B.CreateCall(GuardFnType, GuardCheckLoad, {CalledOperand, Thunk},
755                    Bundles);
756 
757   // Ensure that the first argument is passed in the correct register.
758   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
759 
760   CB->setCalledOperand(GuardCheck);
761 }
762 
// Module entry point: caches common types and the Arm64EC runtime helper
// globals, mangles function/alias/personality names to their EC form,
// lowers calls in every non-thunk function, and finally records all
// generated thunks in the "llvm.arm64ec.symbolmap" global consumed later
// during emission.
bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
  if (!GenerateThunks)
    return false;

  M = &Mod;

  // Check if this module has the cfguard flag and read its value.
  if (auto *MD =
          mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
    cfguard_module_flag = MD->getZExtValue();

  // Cache frequently used types.
  PtrTy = PointerType::getUnqual(M->getContext());
  I64Ty = Type::getInt64Ty(M->getContext());
  VoidTy = Type::getVoidTy(M->getContext());

  // Declare the OS-provided helper globals: the indirect-call checkers
  // (plain and CFG variant) and the x64 dispatch helper.
  GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
  DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
  GuardFnCFGlobal = M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", PtrTy);
  GuardFnGlobal = M->getOrInsertGlobal("__os_arm64x_check_icall", PtrTy);
  DispatchFnGlobal = M->getOrInsertGlobal("__os_arm64x_dispatch_call", PtrTy);

  // Mangle names of function aliases and add the alias name to
  // arm64ec_unmangled_name metadata to ensure a weak anti-dependency symbol is
  // emitted for the alias as well. Do this early, before handling
  // hybrid_patchable functions, to avoid mangling their aliases.
  for (GlobalAlias &A : Mod.aliases()) {
    auto F = dyn_cast_or_null<Function>(A.getAliaseeObject());
    if (!F)
      continue;
    if (std::optional<std::string> MangledName =
            getArm64ECMangledFunctionName(A.getName().str())) {
      F->addMetadata("arm64ec_unmangled_name",
                     *MDNode::get(M->getContext(),
                                  MDString::get(M->getContext(), A.getName())));
      A.setName(MangledName.value());
    }
  }

  // FnsMap maps the unmangled hybrid_patchable alias to its mangled twin;
  // PatchableFns remembers which aliases need a patchable thunk below.
  DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
  SetVector<GlobalAlias *> PatchableFns;

  for (Function &F : Mod) {
    // Personality routines are referenced by symbol name, so they must be
    // renamed to the mangled EC name as well.
    if (F.hasPersonalityFn()) {
      GlobalValue *PersFn =
          cast<GlobalValue>(F.getPersonalityFn()->stripPointerCasts());
      if (PersFn->getValueType() && PersFn->getValueType()->isFunctionTy()) {
        if (std::optional<std::string> MangledName =
                getArm64ECMangledFunctionName(PersFn->getName().str())) {
          PersFn->setName(MangledName.value());
        }
      }
    }

    // Only defined, externally-visible hybrid_patchable functions that have
    // not already been given the target suffix need the alias treatment.
    if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() ||
        F.hasLocalLinkage() ||
        F.getName().ends_with(HybridPatchableTargetSuffix))
      continue;

    // Rename hybrid patchable functions and change callers to use a global
    // alias instead.
    if (std::optional<std::string> MangledName =
            getArm64ECMangledFunctionName(F.getName().str())) {
      std::string OrigName(F.getName());
      F.setName(MangledName.value() + HybridPatchableTargetSuffix);

      // The unmangled symbol is a weak alias to an undefined symbol with the
      // "EXP+" prefix. This undefined symbol is resolved by the linker by
      // creating an x86 thunk that jumps back to the actual EC target. Since we
      // can't represent that in IR, we create an alias to the target instead.
      // The "EXP+" symbol is set as metadata, which is then used by
      // emitGlobalAlias to emit the right alias.
      auto *A =
          GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F);
      auto *AM = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage,
                                     MangledName.value(), &F);
      // Redirect alias users of F to the mangled alias first, then all
      // remaining users to the unmangled alias. Both RAUWs also rewrite the
      // aliasee operands of A/AM themselves, so re-point them back at F.
      F.replaceUsesWithIf(AM,
                          [](Use &U) { return isa<GlobalAlias>(U.getUser()); });
      F.replaceAllUsesWith(A);
      F.setMetadata("arm64ec_exp_name",
                    MDNode::get(M->getContext(),
                                MDString::get(M->getContext(),
                                              "EXP+" + MangledName.value())));
      A->setAliasee(&F);
      AM->setAliasee(&F);

      // The unmangled alias takes over the dllexport; the renamed
      // definition is no longer exported directly.
      if (F.hasDLLExportStorageClass()) {
        A->setDLLStorageClass(GlobalValue::DLLExportStorageClass);
        F.setDLLStorageClass(GlobalValue::DefaultStorageClass);
      }

      FnsMap[A] = AM;
      PatchableFns.insert(A);
    }
  }

  // Lower calls in every defined function, except thunks generated with
  // the special Arm64EC thunk calling conventions.
  SetVector<GlobalValue *> DirectCalledFns;
  for (Function &F : Mod)
    if (!F.isDeclaration() &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
      processFunction(F, DirectCalledFns, FnsMap);

  // One record per generated thunk: the two related symbols plus the thunk
  // kind, later serialized into the symbol map global.
  struct ThunkInfo {
    Constant *Src;
    Constant *Dst;
    Arm64ECThunkType Kind;
  };
  SmallVector<ThunkInfo> ThunkMapping;
  // Entry thunks for every defined function that is externally visible or
  // has its address taken.
  for (Function &F : Mod) {
    if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
      if (!F.hasComdat())
        F.setComdat(Mod.getOrInsertComdat(F.getName()));
      ThunkMapping.push_back(
          {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
    }
  }
  // Exit thunks for directly-called external functions; plain functions
  // (not aliases) that are not dllimport additionally get a guest-exit
  // thunk.
  for (GlobalValue *O : DirectCalledFns) {
    auto GA = dyn_cast<GlobalAlias>(O);
    auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O);
    ThunkMapping.push_back(
        {O, buildExitThunk(F->getFunctionType(), F->getAttributes()),
         Arm64ECThunkType::Exit});
    if (!GA && !F->hasDLLImportStorageClass())
      ThunkMapping.push_back(
          {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
  }
  // Patchable thunks for the hybrid_patchable aliases created above.
  for (GlobalAlias *A : PatchableFns) {
    Function *Thunk = buildPatchableThunk(A, FnsMap[A]);
    ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit});
  }

  // Serialize the collected thunks into the "llvm.arm64ec.symbolmap"
  // array of {src, dst, kind} records.
  if (!ThunkMapping.empty()) {
    SmallVector<Constant *> ThunkMappingArrayElems;
    for (ThunkInfo &Thunk : ThunkMapping) {
      ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
          {Thunk.Src, Thunk.Dst,
           ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))}));
    }
    Constant *ThunkMappingArray = ConstantArray::get(
        llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
                             ThunkMappingArrayElems.size()),
        ThunkMappingArrayElems);
    new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
                       GlobalValue::ExternalLinkage, ThunkMappingArray,
                       "llvm.arm64ec.symbolmap");
  }

  return true;
}
914 
// Process one function: rename it (and its comdat) to the Arm64EC mangled
// name, collect direct calls to external declarations into DirectCalledFns
// (so exit thunks can be built for them), redirect direct calls to
// hybrid_patchable aliases through FnsMap, and lower all remaining
// indirect calls via lowerCall. Returns true if any indirect calls were
// lowered.
bool AArch64Arm64ECCallLowering::processFunction(
    Function &F, SetVector<GlobalValue *> &DirectCalledFns,
    DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
  SmallVector<CallBase *, 8> IndirectCalls;

  // For ARM64EC targets, a function definition's name is mangled differently
  // from the normal symbol. We currently have no representation of this sort
  // of symbol in IR, so we change the name to the mangled name, then store
  // the unmangled name as metadata.  Later passes that need the unmangled
  // name (emitting the definition) can grab it from the metadata.
  //
  // FIXME: Handle functions with weak linkage?
  if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
    if (std::optional<std::string> MangledName =
            getArm64ECMangledFunctionName(F.getName().str())) {
      // Record the unmangled name before the rename below changes it.
      F.addMetadata("arm64ec_unmangled_name",
                    *MDNode::get(M->getContext(),
                                 MDString::get(M->getContext(), F.getName())));
      // If the function's comdat is keyed on its (old) name, create a
      // comdat keyed on the mangled name and migrate every member to it so
      // the comdat key stays in sync with the renamed symbol.
      if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
        Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
        SmallVector<GlobalObject *> ComdatUsers =
            to_vector(F.getComdat()->getUsers());
        for (GlobalObject *User : ComdatUsers)
          User->setComdat(MangledComdat);
      }
      F.setName(MangledName.value());
    }
  }

  // Iterate over the instructions to find all indirect call/invoke/callbr
  // instructions. Make a separate list of pointers to indirect
  // call/invoke/callbr instructions because the original instructions will be
  // deleted as the checks are added.
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      auto *CB = dyn_cast<CallBase>(&I);
      // Skip non-call instructions, calls using the x64 thunk convention,
      // and inline asm.
      if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
          CB->isInlineAsm())
        continue;

      // We need to instrument any call that isn't directly calling an
      // ARM64 function.
      //
      // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
      // unprototyped functions in C)
      if (Function *F = CB->getCalledFunction()) {
        // Direct calls to local, intrinsic, or defined functions need no
        // handling; direct calls to external declarations only need an exit
        // thunk (recorded via DirectCalledFns), not an inline check.
        if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
            F->isIntrinsic() || !F->isDeclaration())
          continue;

        DirectCalledFns.insert(F);
        continue;
      }

      // Use mangled global alias for direct calls to patchable functions.
      if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) {
        auto I = FnsMap.find(A);
        if (I != FnsMap.end()) {
          CB->setCalledOperand(I->second);
          DirectCalledFns.insert(I->first);
          continue;
        }
      }

      // Genuinely indirect call: defer lowering until after the scan.
      IndirectCalls.push_back(CB);
      ++Arm64ECCallsLowered;
    }
  }

  if (IndirectCalls.empty())
    return false;

  for (CallBase *CB : IndirectCalls)
    lowerCall(CB);

  return true;
}
992 
// Pass identification; the address of ID is used as a unique pass token.
char AArch64Arm64ECCallLowering::ID = 0;
// Register the pass with the legacy pass manager (not CFG-only, not an
// analysis).
INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
                "AArch64Arm64ECCallLowering", false, false)
996 
997 ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
998   return new AArch64Arm64ECCallLowering;
999 }
1000