xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file contains the IR transform to lower external or indirect calls for
11 /// the ARM64EC calling convention. Such calls must go through the runtime, so
12 /// we can translate the calling convention for calls into the emulator.
13 ///
14 /// This subsumes Control Flow Guard handling.
15 ///
16 //===----------------------------------------------------------------------===//
17 
18 #include "AArch64.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/IR/CallingConv.h"
24 #include "llvm/IR/GlobalAlias.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Instruction.h"
27 #include "llvm/IR/Mangler.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/InitializePasses.h"
30 #include "llvm/Object/COFF.h"
31 #include "llvm/Pass.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/TargetParser/Triple.h"
34 
35 using namespace llvm;
36 using namespace llvm::COFF;
37 
38 using OperandBundleDef = OperandBundleDefT<Value *>;
39 
40 #define DEBUG_TYPE "arm64eccalllowering"
41 
42 STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
43 
44 static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
45                                            cl::Hidden, cl::init(true));
46 static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
47                                     cl::init(true));
48 
49 namespace {
50 
51 enum ThunkArgTranslation : uint8_t {
52   Direct,
53   Bitcast,
54   PointerIndirection,
55 };
56 
57 struct ThunkArgInfo {
58   Type *Arm64Ty;
59   Type *X64Ty;
60   ThunkArgTranslation Translation;
61 };
62 
63 class AArch64Arm64ECCallLowering : public ModulePass {
64 public:
65   static char ID;
66   AArch64Arm64ECCallLowering() : ModulePass(ID) {
67     initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
68   }
69 
70   Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
71   Function *buildEntryThunk(Function *F);
72   void lowerCall(CallBase *CB);
73   Function *buildGuestExitThunk(Function *F);
74   Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,
75                                 GlobalAlias *MangledAlias);
76   bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,
77                        DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
78   bool runOnModule(Module &M) override;
79 
80 private:
81   int cfguard_module_flag = 0;
82   FunctionType *GuardFnType = nullptr;
83   PointerType *GuardFnPtrType = nullptr;
84   FunctionType *DispatchFnType = nullptr;
85   PointerType *DispatchFnPtrType = nullptr;
86   Constant *GuardFnCFGlobal = nullptr;
87   Constant *GuardFnGlobal = nullptr;
88   Constant *DispatchFnGlobal = nullptr;
89   Module *M = nullptr;
90 
91   Type *PtrTy;
92   Type *I64Ty;
93   Type *VoidTy;
94 
95   void getThunkType(FunctionType *FT, AttributeList AttrList,
96                     Arm64ECThunkType TT, raw_ostream &Out,
97                     FunctionType *&Arm64Ty, FunctionType *&X64Ty,
98                     SmallVector<ThunkArgTranslation> &ArgTranslations);
99   void getThunkRetType(FunctionType *FT, AttributeList AttrList,
100                        raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
101                        SmallVectorImpl<Type *> &Arm64ArgTypes,
102                        SmallVectorImpl<Type *> &X64ArgTypes,
103                        SmallVector<ThunkArgTranslation> &ArgTranslations,
104                        bool &HasSretPtr);
105   void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
106                         Arm64ECThunkType TT, raw_ostream &Out,
107                         SmallVectorImpl<Type *> &Arm64ArgTypes,
108                         SmallVectorImpl<Type *> &X64ArgTypes,
109                         SmallVectorImpl<ThunkArgTranslation> &ArgTranslations,
110                         bool HasSretPtr);
111   ThunkArgInfo canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
112                                      uint64_t ArgSizeBytes, raw_ostream &Out);
113 };
114 
115 } // end anonymous namespace
116 
117 void AArch64Arm64ECCallLowering::getThunkType(
118     FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
119     raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty,
120     SmallVector<ThunkArgTranslation> &ArgTranslations) {
121   Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
122                                         : "$iexit_thunk$cdecl$");
123 
124   Type *Arm64RetTy;
125   Type *X64RetTy;
126 
127   SmallVector<Type *> Arm64ArgTypes;
128   SmallVector<Type *> X64ArgTypes;
129 
130   // The first argument to a thunk is the called function, stored in x9.
131   // For exit thunks, we pass the called function down to the emulator;
132   // for entry/guest exit thunks, we just call the Arm64 function directly.
133   if (TT == Arm64ECThunkType::Exit)
134     Arm64ArgTypes.push_back(PtrTy);
135   X64ArgTypes.push_back(PtrTy);
136 
137   bool HasSretPtr = false;
138   getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
139                   X64ArgTypes, ArgTranslations, HasSretPtr);
140 
141   getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
142                    ArgTranslations, HasSretPtr);
143 
144   Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
145 
146   X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
147 }
148 
149 void AArch64Arm64ECCallLowering::getThunkArgTypes(
150     FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
151     raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
152     SmallVectorImpl<Type *> &X64ArgTypes,
153     SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) {
154 
155   Out << "$";
156   if (FT->isVarArg()) {
157     // We treat the variadic function's thunk as a normal function
158     // with the following type on the ARM side:
159     //   rettype exitthunk(
160     //     ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
161     //
162     // that can coverage all types of variadic function.
163     // x9 is similar to normal exit thunk, store the called function.
164     // x0-x3 is the arguments be stored in registers.
165     // x4 is the address of the arguments on the stack.
166     // x5 is the size of the arguments on the stack.
167     //
168     // On the x64 side, it's the same except that x5 isn't set.
169     //
170     // If both the ARM and X64 sides are sret, there are only three
171     // arguments in registers.
172     //
173     // If the X64 side is sret, but the ARM side isn't, we pass an extra value
174     // to/from the X64 side, and let SelectionDAG transform it into a memory
175     // location.
176     Out << "varargs";
177 
178     // x0-x3
179     for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
180       Arm64ArgTypes.push_back(I64Ty);
181       X64ArgTypes.push_back(I64Ty);
182       ArgTranslations.push_back(ThunkArgTranslation::Direct);
183     }
184 
185     // x4
186     Arm64ArgTypes.push_back(PtrTy);
187     X64ArgTypes.push_back(PtrTy);
188     ArgTranslations.push_back(ThunkArgTranslation::Direct);
189     // x5
190     Arm64ArgTypes.push_back(I64Ty);
191     if (TT != Arm64ECThunkType::Entry) {
192       // FIXME: x5 isn't actually used by the x64 side; revisit once we
193       // have proper isel for varargs
194       X64ArgTypes.push_back(I64Ty);
195       ArgTranslations.push_back(ThunkArgTranslation::Direct);
196     }
197     return;
198   }
199 
200   unsigned I = 0;
201   if (HasSretPtr)
202     I++;
203 
204   if (I == FT->getNumParams()) {
205     Out << "v";
206     return;
207   }
208 
209   for (unsigned E = FT->getNumParams(); I != E; ++I) {
210 #if 0
211     // FIXME: Need more information about argument size; see
212     // https://reviews.llvm.org/D132926
213     uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
214     Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
215 #else
216     uint64_t ArgSizeBytes = 0;
217     Align ParamAlign = Align();
218 #endif
219     auto [Arm64Ty, X64Ty, ArgTranslation] =
220         canonicalizeThunkType(FT->getParamType(I), ParamAlign,
221                               /*Ret*/ false, ArgSizeBytes, Out);
222     Arm64ArgTypes.push_back(Arm64Ty);
223     X64ArgTypes.push_back(X64Ty);
224     ArgTranslations.push_back(ArgTranslation);
225   }
226 }
227 
228 void AArch64Arm64ECCallLowering::getThunkRetType(
229     FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
230     Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
231     SmallVectorImpl<Type *> &X64ArgTypes,
232     SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) {
233   Type *T = FT->getReturnType();
234 #if 0
235   // FIXME: Need more information about argument size; see
236   // https://reviews.llvm.org/D132926
237   uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
238 #else
239   int64_t ArgSizeBytes = 0;
240 #endif
241   if (T->isVoidTy()) {
242     if (FT->getNumParams()) {
243       Attribute SRetAttr0 = AttrList.getParamAttr(0, Attribute::StructRet);
244       Attribute InRegAttr0 = AttrList.getParamAttr(0, Attribute::InReg);
245       Attribute SRetAttr1, InRegAttr1;
246       if (FT->getNumParams() > 1) {
247         // Also check the second parameter (for class methods, the first
248         // parameter is "this", and the second parameter is the sret pointer.)
249         // It doesn't matter which one is sret.
250         SRetAttr1 = AttrList.getParamAttr(1, Attribute::StructRet);
251         InRegAttr1 = AttrList.getParamAttr(1, Attribute::InReg);
252       }
253       if ((SRetAttr0.isValid() && InRegAttr0.isValid()) ||
254           (SRetAttr1.isValid() && InRegAttr1.isValid())) {
255         // sret+inreg indicates a call that returns a C++ class value. This is
256         // actually equivalent to just passing and returning a void* pointer
257         // as the first or second argument. Translate it that way, instead of
258         // trying to model "inreg" in the thunk's calling convention; this
259         // simplfies the rest of the code, and matches MSVC mangling.
260         Out << "i8";
261         Arm64RetTy = I64Ty;
262         X64RetTy = I64Ty;
263         return;
264       }
265       if (SRetAttr0.isValid()) {
266         // FIXME: Sanity-check the sret type; if it's an integer or pointer,
267         // we'll get screwy mangling/codegen.
268         // FIXME: For large struct types, mangle as an integer argument and
269         // integer return, so we can reuse more thunks, instead of "m" syntax.
270         // (MSVC mangles this case as an integer return with no argument, but
271         // that's a miscompile.)
272         Type *SRetType = SRetAttr0.getValueAsType();
273         Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
274         canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
275                               Out);
276         Arm64RetTy = VoidTy;
277         X64RetTy = VoidTy;
278         Arm64ArgTypes.push_back(FT->getParamType(0));
279         X64ArgTypes.push_back(FT->getParamType(0));
280         ArgTranslations.push_back(ThunkArgTranslation::Direct);
281         HasSretPtr = true;
282         return;
283       }
284     }
285 
286     Out << "v";
287     Arm64RetTy = VoidTy;
288     X64RetTy = VoidTy;
289     return;
290   }
291 
292   auto info =
293       canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out);
294   Arm64RetTy = info.Arm64Ty;
295   X64RetTy = info.X64Ty;
296   if (X64RetTy->isPointerTy()) {
297     // If the X64 type is canonicalized to a pointer, that means it's
298     // passed/returned indirectly. For a return value, that means it's an
299     // sret pointer.
300     X64ArgTypes.push_back(X64RetTy);
301     X64RetTy = VoidTy;
302   }
303 }
304 
305 ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
306     Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,
307     raw_ostream &Out) {
308 
309   auto direct = [](Type *T) {
310     return ThunkArgInfo{T, T, ThunkArgTranslation::Direct};
311   };
312 
313   auto bitcast = [this](Type *Arm64Ty, uint64_t SizeInBytes) {
314     return ThunkArgInfo{Arm64Ty,
315                         llvm::Type::getIntNTy(M->getContext(), SizeInBytes * 8),
316                         ThunkArgTranslation::Bitcast};
317   };
318 
319   auto pointerIndirection = [this](Type *Arm64Ty) {
320     return ThunkArgInfo{Arm64Ty, PtrTy,
321                         ThunkArgTranslation::PointerIndirection};
322   };
323 
324   if (T->isFloatTy()) {
325     Out << "f";
326     return direct(T);
327   }
328 
329   if (T->isDoubleTy()) {
330     Out << "d";
331     return direct(T);
332   }
333 
334   if (T->isFloatingPointTy()) {
335     report_fatal_error(
336         "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
337   }
338 
339   auto &DL = M->getDataLayout();
340 
341   if (auto *StructTy = dyn_cast<StructType>(T))
342     if (StructTy->getNumElements() == 1)
343       T = StructTy->getElementType(0);
344 
345   if (T->isArrayTy()) {
346     Type *ElementTy = T->getArrayElementType();
347     uint64_t ElementCnt = T->getArrayNumElements();
348     uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
349     uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
350     if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
351       Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
352       if (Alignment.value() >= 16 && !Ret)
353         Out << "a" << Alignment.value();
354       if (TotalSizeBytes <= 8) {
355         // Arm64 returns small structs of float/double in float registers;
356         // X64 uses RAX.
357         return bitcast(T, TotalSizeBytes);
358       } else {
359         // Struct is passed directly on Arm64, but indirectly on X64.
360         return pointerIndirection(T);
361       }
362     } else if (T->isFloatingPointTy()) {
363       report_fatal_error("Only 32 and 64 bit floating points are supported for "
364                          "ARM64EC thunks");
365     }
366   }
367 
368   if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
369     Out << "i8";
370     return direct(I64Ty);
371   }
372 
373   unsigned TypeSize = ArgSizeBytes;
374   if (TypeSize == 0)
375     TypeSize = DL.getTypeSizeInBits(T) / 8;
376   Out << "m";
377   if (TypeSize != 4)
378     Out << TypeSize;
379   if (Alignment.value() >= 16 && !Ret)
380     Out << "a" << Alignment.value();
381   // FIXME: Try to canonicalize Arm64Ty more thoroughly?
382   if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
383     // Pass directly in an integer register
384     return bitcast(T, TypeSize);
385   } else {
386     // Passed directly on Arm64, but indirectly on X64.
387     return pointerIndirection(T);
388   }
389 }
390 
391 // This function builds the "exit thunk", a function which translates
392 // arguments and return values when calling x64 code from AArch64 code.
393 Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
394                                                      AttributeList Attrs) {
395   SmallString<256> ExitThunkName;
396   llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
397   FunctionType *Arm64Ty, *X64Ty;
398   SmallVector<ThunkArgTranslation> ArgTranslations;
399   getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
400                X64Ty, ArgTranslations);
401   if (Function *F = M->getFunction(ExitThunkName))
402     return F;
403 
404   Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
405                                  ExitThunkName, M);
406   F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
407   F->setSection(".wowthk$aa");
408   F->setComdat(M->getOrInsertComdat(ExitThunkName));
409   // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
410   F->addFnAttr("frame-pointer", "all");
411   // Only copy sret from the first argument. For C++ instance methods, clang can
412   // stick an sret marking on a later argument, but it doesn't actually affect
413   // the ABI, so we can omit it. This avoids triggering a verifier assertion.
414   if (FT->getNumParams()) {
415     auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
416     auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
417     if (SRet.isValid() && !InReg.isValid())
418       F->addParamAttr(1, SRet);
419   }
420   // FIXME: Copy anything other than sret?  Shouldn't be necessary for normal
421   // C ABI, but might show up in other cases.
422   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
423   IRBuilder<> IRB(BB);
424   Value *CalleePtr =
425       M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
426   Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
427   auto &DL = M->getDataLayout();
428   SmallVector<Value *> Args;
429 
430   // Pass the called function in x9.
431   auto X64TyOffset = 1;
432   Args.push_back(F->arg_begin());
433 
434   Type *RetTy = Arm64Ty->getReturnType();
435   if (RetTy != X64Ty->getReturnType()) {
436     // If the return type is an array or struct, translate it. Values of size
437     // 8 or less go into RAX; bigger values go into memory, and we pass a
438     // pointer.
439     if (DL.getTypeStoreSize(RetTy) > 8) {
440       Args.push_back(IRB.CreateAlloca(RetTy));
441       X64TyOffset++;
442     }
443   }
444 
445   for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal(
446            make_range(F->arg_begin() + 1, F->arg_end()),
447            make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()),
448            ArgTranslations)) {
449     // Translate arguments from AArch64 calling convention to x86 calling
450     // convention.
451     //
452     // For simple types, we don't need to do any translation: they're
453     // represented the same way. (Implicit sign extension is not part of
454     // either convention.)
455     //
456     // The big thing we have to worry about is struct types... but
457     // fortunately AArch64 clang is pretty friendly here: the cases that need
458     // translation are always passed as a struct or array. (If we run into
459     // some cases where this doesn't work, we can teach clang to mark it up
460     // with an attribute.)
461     //
462     // The first argument is the called function, stored in x9.
463     if (ArgTranslation != ThunkArgTranslation::Direct) {
464       Value *Mem = IRB.CreateAlloca(Arg.getType());
465       IRB.CreateStore(&Arg, Mem);
466       if (ArgTranslation == ThunkArgTranslation::Bitcast) {
467         Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
468         Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
469       } else {
470         assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
471         Args.push_back(Mem);
472       }
473     } else {
474       Args.push_back(&Arg);
475     }
476     assert(Args.back()->getType() == X64ArgType);
477   }
478   // FIXME: Transfer necessary attributes? sret? anything else?
479 
480   Callee = IRB.CreateBitCast(Callee, PtrTy);
481   CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
482   Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
483 
484   Value *RetVal = Call;
485   if (RetTy != X64Ty->getReturnType()) {
486     // If we rewrote the return type earlier, convert the return value to
487     // the proper type.
488     if (DL.getTypeStoreSize(RetTy) > 8) {
489       RetVal = IRB.CreateLoad(RetTy, Args[1]);
490     } else {
491       Value *CastAlloca = IRB.CreateAlloca(RetTy);
492       IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
493       RetVal = IRB.CreateLoad(RetTy, CastAlloca);
494     }
495   }
496 
497   if (RetTy->isVoidTy())
498     IRB.CreateRetVoid();
499   else
500     IRB.CreateRet(RetVal);
501   return F;
502 }
503 
504 // This function builds the "entry thunk", a function which translates
505 // arguments and return values when calling AArch64 code from x64 code.
506 Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
507   SmallString<256> EntryThunkName;
508   llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
509   FunctionType *Arm64Ty, *X64Ty;
510   SmallVector<ThunkArgTranslation> ArgTranslations;
511   getThunkType(F->getFunctionType(), F->getAttributes(),
512                Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,
513                ArgTranslations);
514   if (Function *F = M->getFunction(EntryThunkName))
515     return F;
516 
517   Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
518                                      EntryThunkName, M);
519   Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
520   Thunk->setSection(".wowthk$aa");
521   Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
522   // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
523   Thunk->addFnAttr("frame-pointer", "all");
524 
525   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
526   IRBuilder<> IRB(BB);
527 
528   Type *RetTy = Arm64Ty->getReturnType();
529   Type *X64RetType = X64Ty->getReturnType();
530 
531   bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
532   unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
533   unsigned PassthroughArgSize =
534       (F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset;
535   assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize));
536 
537   // Translate arguments to call.
538   SmallVector<Value *> Args;
539   for (unsigned i = 0; i != PassthroughArgSize; ++i) {
540     Value *Arg = Thunk->getArg(i + ThunkArgOffset);
541     Type *ArgTy = Arm64Ty->getParamType(i);
542     ThunkArgTranslation ArgTranslation = ArgTranslations[i];
543     if (ArgTranslation != ThunkArgTranslation::Direct) {
544       // Translate array/struct arguments to the expected type.
545       if (ArgTranslation == ThunkArgTranslation::Bitcast) {
546         Value *CastAlloca = IRB.CreateAlloca(ArgTy);
547         IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
548         Arg = IRB.CreateLoad(ArgTy, CastAlloca);
549       } else {
550         assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
551         Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
552       }
553     }
554     assert(Arg->getType() == ArgTy);
555     Args.push_back(Arg);
556   }
557 
558   if (F->isVarArg()) {
559     // The 5th argument to variadic entry thunks is used to model the x64 sp
560     // which is passed to the thunk in x4, this can be passed to the callee as
561     // the variadic argument start address after skipping over the 32 byte
562     // shadow store.
563 
564     // The EC thunk CC will assign any argument marked as InReg to x4.
565     Thunk->addParamAttr(5, Attribute::InReg);
566     Value *Arg = Thunk->getArg(5);
567     Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
568     Args.push_back(Arg);
569 
570     // Pass in a zero variadic argument size (in x5).
571     Args.push_back(IRB.getInt64(0));
572   }
573 
574   // Call the function passed to the thunk.
575   Value *Callee = Thunk->getArg(0);
576   Callee = IRB.CreateBitCast(Callee, PtrTy);
577   CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
578 
579   auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
580   auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
581   if (SRetAttr.isValid() && !InRegAttr.isValid()) {
582     Thunk->addParamAttr(1, SRetAttr);
583     Call->addParamAttr(0, SRetAttr);
584   }
585 
586   Value *RetVal = Call;
587   if (TransformDirectToSRet) {
588     IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
589   } else if (X64RetType != RetTy) {
590     Value *CastAlloca = IRB.CreateAlloca(X64RetType);
591     IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
592     RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
593   }
594 
595   // Return to the caller.  Note that the isel has code to translate this
596   // "ret" to a tail call to __os_arm64x_dispatch_ret.  (Alternatively, we
597   // could emit a tail call here, but that would require a dedicated calling
598   // convention, which seems more complicated overall.)
599   if (X64RetType->isVoidTy())
600     IRB.CreateRetVoid();
601   else
602     IRB.CreateRet(RetVal);
603 
604   return Thunk;
605 }
606 
607 // Builds the "guest exit thunk", a helper to call a function which may or may
608 // not be an exit thunk. (We optimistically assume non-dllimport function
609 // declarations refer to functions defined in AArch64 code; if the linker
610 // can't prove that, we use this routine instead.)
611 Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
612   llvm::raw_null_ostream NullThunkName;
613   FunctionType *Arm64Ty, *X64Ty;
614   SmallVector<ThunkArgTranslation> ArgTranslations;
615   getThunkType(F->getFunctionType(), F->getAttributes(),
616                Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
617                ArgTranslations);
618   auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
619   assert(MangledName && "Can't guest exit to function that's already native");
620   std::string ThunkName = *MangledName;
621   if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
622     ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
623   } else {
624     ThunkName.append("$exit_thunk");
625   }
626   Function *GuestExit =
627       Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
628   GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
629   GuestExit->setSection(".wowthk$aa");
630   GuestExit->setMetadata(
631       "arm64ec_unmangled_name",
632       MDNode::get(M->getContext(),
633                   MDString::get(M->getContext(), F->getName())));
634   GuestExit->setMetadata(
635       "arm64ec_ecmangled_name",
636       MDNode::get(M->getContext(),
637                   MDString::get(M->getContext(), *MangledName)));
638   F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
639   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
640   IRBuilder<> B(BB);
641 
642   // Load the global symbol as a pointer to the check function.
643   Value *GuardFn;
644   if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
645     GuardFn = GuardFnCFGlobal;
646   else
647     GuardFn = GuardFnGlobal;
648   LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
649 
650   // Create new call instruction. The CFGuard check should always be a call,
651   // even if the original CallBase is an Invoke or CallBr instruction.
652   Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
653   CallInst *GuardCheck = B.CreateCall(
654       GuardFnType, GuardCheckLoad,
655       {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
656 
657   // Ensure that the first argument is passed in the correct register.
658   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
659 
660   Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
661   SmallVector<Value *> Args;
662   for (Argument &Arg : GuestExit->args())
663     Args.push_back(&Arg);
664   CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
665   Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
666 
667   if (Call->getType()->isVoidTy())
668     B.CreateRetVoid();
669   else
670     B.CreateRet(Call);
671 
672   auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
673   auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
674   if (SRetAttr.isValid() && !InRegAttr.isValid()) {
675     GuestExit->addParamAttr(0, SRetAttr);
676     Call->addParamAttr(0, SRetAttr);
677   }
678 
679   return GuestExit;
680 }
681 
682 Function *
683 AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,
684                                                 GlobalAlias *MangledAlias) {
685   llvm::raw_null_ostream NullThunkName;
686   FunctionType *Arm64Ty, *X64Ty;
687   Function *F = cast<Function>(MangledAlias->getAliasee());
688   SmallVector<ThunkArgTranslation> ArgTranslations;
689   getThunkType(F->getFunctionType(), F->getAttributes(),
690                Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
691                ArgTranslations);
692   std::string ThunkName(MangledAlias->getName());
693   if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
694     ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");
695   } else {
696     ThunkName.append("$hybpatch_thunk");
697   }
698 
699   Function *GuestExit =
700       Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
701   GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
702   GuestExit->setSection(".wowthk$aa");
703   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
704   IRBuilder<> B(BB);
705 
706   // Load the global symbol as a pointer to the check function.
707   LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal);
708 
709   // Create new dispatch call instruction.
710   Function *ExitThunk =
711       buildExitThunk(F->getFunctionType(), F->getAttributes());
712   CallInst *Dispatch =
713       B.CreateCall(DispatchFnType, DispatchLoad,
714                    {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()});
715 
716   // Ensure that the first arguments are passed in the correct registers.
717   Dispatch->setCallingConv(CallingConv::CFGuard_Check);
718 
719   Value *DispatchRetVal = B.CreateBitCast(Dispatch, PtrTy);
720   SmallVector<Value *> Args;
721   for (Argument &Arg : GuestExit->args())
722     Args.push_back(&Arg);
723   CallInst *Call = B.CreateCall(Arm64Ty, DispatchRetVal, Args);
724   Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
725 
726   if (Call->getType()->isVoidTy())
727     B.CreateRetVoid();
728   else
729     B.CreateRet(Call);
730 
731   auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
732   auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
733   if (SRetAttr.isValid() && !InRegAttr.isValid()) {
734     GuestExit->addParamAttr(0, SRetAttr);
735     Call->addParamAttr(0, SRetAttr);
736   }
737 
738   MangledAlias->setAliasee(GuestExit);
739   return GuestExit;
740 }
741 
742 // Lower an indirect call with inline code.
743 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
744   assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
745          "Only applicable for Windows targets");
746 
747   IRBuilder<> B(CB);
748   Value *CalledOperand = CB->getCalledOperand();
749 
750   // If the indirect call is called within catchpad or cleanuppad,
751   // we need to copy "funclet" bundle of the call.
752   SmallVector<llvm::OperandBundleDef, 1> Bundles;
753   if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
754     Bundles.push_back(OperandBundleDef(*Bundle));
755 
756   // Load the global symbol as a pointer to the check function.
757   Value *GuardFn;
758   if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
759     GuardFn = GuardFnCFGlobal;
760   else
761     GuardFn = GuardFnGlobal;
762   LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
763 
764   // Create new call instruction. The CFGuard check should always be a call,
765   // even if the original CallBase is an Invoke or CallBr instruction.
766   Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
767   CallInst *GuardCheck =
768       B.CreateCall(GuardFnType, GuardCheckLoad,
769                    {B.CreateBitCast(CalledOperand, B.getPtrTy()),
770                     B.CreateBitCast(Thunk, B.getPtrTy())},
771                    Bundles);
772 
773   // Ensure that the first argument is passed in the correct register.
774   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
775 
776   Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
777   CB->setCalledOperand(GuardRetVal);
778 }
779 
780 bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
781   if (!GenerateThunks)
782     return false;
783 
784   M = &Mod;
785 
786   // Check if this module has the cfguard flag and read its value.
787   if (auto *MD =
788           mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
789     cfguard_module_flag = MD->getZExtValue();
790 
791   PtrTy = PointerType::getUnqual(M->getContext());
792   I64Ty = Type::getInt64Ty(M->getContext());
793   VoidTy = Type::getVoidTy(M->getContext());
794 
795   GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
796   GuardFnPtrType = PointerType::get(GuardFnType, 0);
797   DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
798   DispatchFnPtrType = PointerType::get(DispatchFnType, 0);
799   GuardFnCFGlobal =
800       M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
801   GuardFnGlobal =
802       M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
803   DispatchFnGlobal =
804       M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType);
805 
806   DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
807   SetVector<GlobalAlias *> PatchableFns;
808 
809   for (Function &F : Mod) {
810     if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() ||
811         F.hasLocalLinkage() || F.getName().ends_with("$hp_target"))
812       continue;
813 
814     // Rename hybrid patchable functions and change callers to use a global
815     // alias instead.
816     if (std::optional<std::string> MangledName =
817             getArm64ECMangledFunctionName(F.getName().str())) {
818       std::string OrigName(F.getName());
819       F.setName(MangledName.value() + "$hp_target");
820 
821       // The unmangled symbol is a weak alias to an undefined symbol with the
822       // "EXP+" prefix. This undefined symbol is resolved by the linker by
823       // creating an x86 thunk that jumps back to the actual EC target. Since we
824       // can't represent that in IR, we create an alias to the target instead.
825       // The "EXP+" symbol is set as metadata, which is then used by
826       // emitGlobalAlias to emit the right alias.
827       auto *A =
828           GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F);
829       F.replaceAllUsesWith(A);
830       F.setMetadata("arm64ec_exp_name",
831                     MDNode::get(M->getContext(),
832                                 MDString::get(M->getContext(),
833                                               "EXP+" + MangledName.value())));
834       A->setAliasee(&F);
835 
836       if (F.hasDLLExportStorageClass()) {
837         A->setDLLStorageClass(GlobalValue::DLLExportStorageClass);
838         F.setDLLStorageClass(GlobalValue::DefaultStorageClass);
839       }
840 
841       FnsMap[A] = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage,
842                                       MangledName.value(), &F);
843       PatchableFns.insert(A);
844     }
845   }
846 
847   SetVector<GlobalValue *> DirectCalledFns;
848   for (Function &F : Mod)
849     if (!F.isDeclaration() &&
850         F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
851         F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
852       processFunction(F, DirectCalledFns, FnsMap);
853 
854   struct ThunkInfo {
855     Constant *Src;
856     Constant *Dst;
857     Arm64ECThunkType Kind;
858   };
859   SmallVector<ThunkInfo> ThunkMapping;
860   for (Function &F : Mod) {
861     if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
862         F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
863         F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
864       if (!F.hasComdat())
865         F.setComdat(Mod.getOrInsertComdat(F.getName()));
866       ThunkMapping.push_back(
867           {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
868     }
869   }
870   for (GlobalValue *O : DirectCalledFns) {
871     auto GA = dyn_cast<GlobalAlias>(O);
872     auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O);
873     ThunkMapping.push_back(
874         {O, buildExitThunk(F->getFunctionType(), F->getAttributes()),
875          Arm64ECThunkType::Exit});
876     if (!GA && !F->hasDLLImportStorageClass())
877       ThunkMapping.push_back(
878           {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
879   }
880   for (GlobalAlias *A : PatchableFns) {
881     Function *Thunk = buildPatchableThunk(A, FnsMap[A]);
882     ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit});
883   }
884 
885   if (!ThunkMapping.empty()) {
886     SmallVector<Constant *> ThunkMappingArrayElems;
887     for (ThunkInfo &Thunk : ThunkMapping) {
888       ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
889           {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
890            ConstantExpr::getBitCast(Thunk.Dst, PtrTy),
891            ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))}));
892     }
893     Constant *ThunkMappingArray = ConstantArray::get(
894         llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
895                              ThunkMappingArrayElems.size()),
896         ThunkMappingArrayElems);
897     new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
898                        GlobalValue::ExternalLinkage, ThunkMappingArray,
899                        "llvm.arm64ec.symbolmap");
900   }
901 
902   return true;
903 }
904 
905 bool AArch64Arm64ECCallLowering::processFunction(
906     Function &F, SetVector<GlobalValue *> &DirectCalledFns,
907     DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
908   SmallVector<CallBase *, 8> IndirectCalls;
909 
910   // For ARM64EC targets, a function definition's name is mangled differently
911   // from the normal symbol. We currently have no representation of this sort
912   // of symbol in IR, so we change the name to the mangled name, then store
913   // the unmangled name as metadata.  Later passes that need the unmangled
914   // name (emitting the definition) can grab it from the metadata.
915   //
916   // FIXME: Handle functions with weak linkage?
917   if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
918     if (std::optional<std::string> MangledName =
919             getArm64ECMangledFunctionName(F.getName().str())) {
920       F.setMetadata("arm64ec_unmangled_name",
921                     MDNode::get(M->getContext(),
922                                 MDString::get(M->getContext(), F.getName())));
923       if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
924         Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
925         SmallVector<GlobalObject *> ComdatUsers =
926             to_vector(F.getComdat()->getUsers());
927         for (GlobalObject *User : ComdatUsers)
928           User->setComdat(MangledComdat);
929       }
930       F.setName(MangledName.value());
931     }
932   }
933 
934   // Iterate over the instructions to find all indirect call/invoke/callbr
935   // instructions. Make a separate list of pointers to indirect
936   // call/invoke/callbr instructions because the original instructions will be
937   // deleted as the checks are added.
938   for (BasicBlock &BB : F) {
939     for (Instruction &I : BB) {
940       auto *CB = dyn_cast<CallBase>(&I);
941       if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
942           CB->isInlineAsm())
943         continue;
944 
945       // We need to instrument any call that isn't directly calling an
946       // ARM64 function.
947       //
948       // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
949       // unprototyped functions in C)
950       if (Function *F = CB->getCalledFunction()) {
951         if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
952             F->isIntrinsic() || !F->isDeclaration())
953           continue;
954 
955         DirectCalledFns.insert(F);
956         continue;
957       }
958 
959       // Use mangled global alias for direct calls to patchable functions.
960       if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) {
961         auto I = FnsMap.find(A);
962         if (I != FnsMap.end()) {
963           CB->setCalledOperand(I->second);
964           DirectCalledFns.insert(I->first);
965           continue;
966         }
967       }
968 
969       IndirectCalls.push_back(CB);
970       ++Arm64ECCallsLowered;
971     }
972   }
973 
974   if (IndirectCalls.empty())
975     return false;
976 
977   for (CallBase *CB : IndirectCalls)
978     lowerCall(CB);
979 
980   return true;
981 }
982 
983 char AArch64Arm64ECCallLowering::ID = 0;
984 INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
985                 "AArch64Arm64ECCallLowering", false, false)
986 
987 ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
988   return new AArch64Arm64ECCallLowering;
989 }
990