xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp (revision b9128a37faafede823eb456aa65a11ac69997284)
1 //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file contains the IR transform to lower external or indirect calls for
11 /// the ARM64EC calling convention. Such calls must go through the runtime, so
12 /// we can translate the calling convention for calls into the emulator.
13 ///
14 /// This subsumes Control Flow Guard handling.
15 ///
16 //===----------------------------------------------------------------------===//
17 
18 #include "AArch64.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/IR/CallingConv.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/Instruction.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Object/COFF.h"
28 #include "llvm/Pass.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/TargetParser/Triple.h"
31 
32 using namespace llvm;
33 using namespace llvm::object;
34 
35 using OperandBundleDef = OperandBundleDefT<Value *>;
36 
37 #define DEBUG_TYPE "arm64eccalllowering"
38 
39 STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
40 
41 static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
42                                            cl::Hidden, cl::init(true));
43 static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
44                                     cl::init(true));
45 
46 namespace {
47 
48 enum class ThunkType { GuestExit, Entry, Exit };
49 
50 class AArch64Arm64ECCallLowering : public ModulePass {
51 public:
52   static char ID;
53   AArch64Arm64ECCallLowering() : ModulePass(ID) {
54     initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
55   }
56 
57   Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
58   Function *buildEntryThunk(Function *F);
59   void lowerCall(CallBase *CB);
60   Function *buildGuestExitThunk(Function *F);
61   bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
62   bool runOnModule(Module &M) override;
63 
64 private:
65   int cfguard_module_flag = 0;
66   FunctionType *GuardFnType = nullptr;
67   PointerType *GuardFnPtrType = nullptr;
68   Constant *GuardFnCFGlobal = nullptr;
69   Constant *GuardFnGlobal = nullptr;
70   Module *M = nullptr;
71 
72   Type *PtrTy;
73   Type *I64Ty;
74   Type *VoidTy;
75 
76   void getThunkType(FunctionType *FT, AttributeList AttrList, ThunkType TT,
77                     raw_ostream &Out, FunctionType *&Arm64Ty,
78                     FunctionType *&X64Ty);
79   void getThunkRetType(FunctionType *FT, AttributeList AttrList,
80                        raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
81                        SmallVectorImpl<Type *> &Arm64ArgTypes,
82                        SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
83   void getThunkArgTypes(FunctionType *FT, AttributeList AttrList, ThunkType TT,
84                         raw_ostream &Out,
85                         SmallVectorImpl<Type *> &Arm64ArgTypes,
86                         SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
87   void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
88                              uint64_t ArgSizeBytes, raw_ostream &Out,
89                              Type *&Arm64Ty, Type *&X64Ty);
90 };
91 
92 } // end anonymous namespace
93 
94 void AArch64Arm64ECCallLowering::getThunkType(FunctionType *FT,
95                                               AttributeList AttrList,
96                                               ThunkType TT, raw_ostream &Out,
97                                               FunctionType *&Arm64Ty,
98                                               FunctionType *&X64Ty) {
99   Out << (TT == ThunkType::Entry ? "$ientry_thunk$cdecl$"
100                                  : "$iexit_thunk$cdecl$");
101 
102   Type *Arm64RetTy;
103   Type *X64RetTy;
104 
105   SmallVector<Type *> Arm64ArgTypes;
106   SmallVector<Type *> X64ArgTypes;
107 
108   // The first argument to a thunk is the called function, stored in x9.
109   // For exit thunks, we pass the called function down to the emulator;
110   // for entry/guest exit thunks, we just call the Arm64 function directly.
111   if (TT == ThunkType::Exit)
112     Arm64ArgTypes.push_back(PtrTy);
113   X64ArgTypes.push_back(PtrTy);
114 
115   bool HasSretPtr = false;
116   getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
117                   X64ArgTypes, HasSretPtr);
118 
119   getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
120                    HasSretPtr);
121 
122   Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
123 
124   X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
125 }
126 
127 void AArch64Arm64ECCallLowering::getThunkArgTypes(
128     FunctionType *FT, AttributeList AttrList, ThunkType TT, raw_ostream &Out,
129     SmallVectorImpl<Type *> &Arm64ArgTypes,
130     SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
131 
132   Out << "$";
133   if (FT->isVarArg()) {
134     // We treat the variadic function's thunk as a normal function
135     // with the following type on the ARM side:
136     //   rettype exitthunk(
137     //     ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
138     //
139     // that can coverage all types of variadic function.
140     // x9 is similar to normal exit thunk, store the called function.
141     // x0-x3 is the arguments be stored in registers.
142     // x4 is the address of the arguments on the stack.
143     // x5 is the size of the arguments on the stack.
144     //
145     // On the x64 side, it's the same except that x5 isn't set.
146     //
147     // If both the ARM and X64 sides are sret, there are only three
148     // arguments in registers.
149     //
150     // If the X64 side is sret, but the ARM side isn't, we pass an extra value
151     // to/from the X64 side, and let SelectionDAG transform it into a memory
152     // location.
153     Out << "varargs";
154 
155     // x0-x3
156     for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
157       Arm64ArgTypes.push_back(I64Ty);
158       X64ArgTypes.push_back(I64Ty);
159     }
160 
161     // x4
162     Arm64ArgTypes.push_back(PtrTy);
163     X64ArgTypes.push_back(PtrTy);
164     // x5
165     Arm64ArgTypes.push_back(I64Ty);
166     if (TT != ThunkType::Entry) {
167       // FIXME: x5 isn't actually used by the x64 side; revisit once we
168       // have proper isel for varargs
169       X64ArgTypes.push_back(I64Ty);
170     }
171     return;
172   }
173 
174   unsigned I = 0;
175   if (HasSretPtr)
176     I++;
177 
178   if (I == FT->getNumParams()) {
179     Out << "v";
180     return;
181   }
182 
183   for (unsigned E = FT->getNumParams(); I != E; ++I) {
184     Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
185 #if 0
186     // FIXME: Need more information about argument size; see
187     // https://reviews.llvm.org/D132926
188     uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
189 #else
190     uint64_t ArgSizeBytes = 0;
191 #endif
192     Type *Arm64Ty, *X64Ty;
193     canonicalizeThunkType(FT->getParamType(I), ParamAlign,
194                           /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
195     Arm64ArgTypes.push_back(Arm64Ty);
196     X64ArgTypes.push_back(X64Ty);
197   }
198 }
199 
200 void AArch64Arm64ECCallLowering::getThunkRetType(
201     FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
202     Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
203     SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
204   Type *T = FT->getReturnType();
205 #if 0
206   // FIXME: Need more information about argument size; see
207   // https://reviews.llvm.org/D132926
208   uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
209 #else
210   int64_t ArgSizeBytes = 0;
211 #endif
212   if (T->isVoidTy()) {
213     if (FT->getNumParams()) {
214       auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet);
215       auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg);
216       if (SRetAttr.isValid() && InRegAttr.isValid()) {
217         // sret+inreg indicates a call that returns a C++ class value. This is
218         // actually equivalent to just passing and returning a void* pointer
219         // as the first argument. Translate it that way, instead of trying
220         // to model "inreg" in the thunk's calling convention, to simplify
221         // the rest of the code.
222         Out << "i8";
223         Arm64RetTy = I64Ty;
224         X64RetTy = I64Ty;
225         return;
226       }
227       if (SRetAttr.isValid()) {
228         // FIXME: Sanity-check the sret type; if it's an integer or pointer,
229         // we'll get screwy mangling/codegen.
230         // FIXME: For large struct types, mangle as an integer argument and
231         // integer return, so we can reuse more thunks, instead of "m" syntax.
232         // (MSVC mangles this case as an integer return with no argument, but
233         // that's a miscompile.)
234         Type *SRetType = SRetAttr.getValueAsType();
235         Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
236         Type *Arm64Ty, *X64Ty;
237         canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
238                               Out, Arm64Ty, X64Ty);
239         Arm64RetTy = VoidTy;
240         X64RetTy = VoidTy;
241         Arm64ArgTypes.push_back(FT->getParamType(0));
242         X64ArgTypes.push_back(FT->getParamType(0));
243         HasSretPtr = true;
244         return;
245       }
246     }
247 
248     Out << "v";
249     Arm64RetTy = VoidTy;
250     X64RetTy = VoidTy;
251     return;
252   }
253 
254   canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64RetTy,
255                         X64RetTy);
256   if (X64RetTy->isPointerTy()) {
257     // If the X64 type is canonicalized to a pointer, that means it's
258     // passed/returned indirectly. For a return value, that means it's an
259     // sret pointer.
260     X64ArgTypes.push_back(X64RetTy);
261     X64RetTy = VoidTy;
262   }
263 }
264 
265 void AArch64Arm64ECCallLowering::canonicalizeThunkType(
266     Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
267     Type *&Arm64Ty, Type *&X64Ty) {
268   if (T->isFloatTy()) {
269     Out << "f";
270     Arm64Ty = T;
271     X64Ty = T;
272     return;
273   }
274 
275   if (T->isDoubleTy()) {
276     Out << "d";
277     Arm64Ty = T;
278     X64Ty = T;
279     return;
280   }
281 
282   if (T->isFloatingPointTy()) {
283     report_fatal_error(
284         "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
285   }
286 
287   auto &DL = M->getDataLayout();
288 
289   if (auto *StructTy = dyn_cast<StructType>(T))
290     if (StructTy->getNumElements() == 1)
291       T = StructTy->getElementType(0);
292 
293   if (T->isArrayTy()) {
294     Type *ElementTy = T->getArrayElementType();
295     uint64_t ElementCnt = T->getArrayNumElements();
296     uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
297     uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
298     if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
299       Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
300       if (Alignment.value() >= 8 && !T->isPointerTy())
301         Out << "a" << Alignment.value();
302       Arm64Ty = T;
303       if (TotalSizeBytes <= 8) {
304         // Arm64 returns small structs of float/double in float registers;
305         // X64 uses RAX.
306         X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8);
307       } else {
308         // Struct is passed directly on Arm64, but indirectly on X64.
309         X64Ty = PtrTy;
310       }
311       return;
312     } else if (T->isFloatingPointTy()) {
313       report_fatal_error("Only 32 and 64 bit floating points are supported for "
314                          "ARM64EC thunks");
315     }
316   }
317 
318   if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
319     Out << "i8";
320     Arm64Ty = I64Ty;
321     X64Ty = I64Ty;
322     return;
323   }
324 
325   unsigned TypeSize = ArgSizeBytes;
326   if (TypeSize == 0)
327     TypeSize = DL.getTypeSizeInBits(T) / 8;
328   Out << "m";
329   if (TypeSize != 4)
330     Out << TypeSize;
331   if (Alignment.value() >= 8 && !T->isPointerTy())
332     Out << "a" << Alignment.value();
333   // FIXME: Try to canonicalize Arm64Ty more thoroughly?
334   Arm64Ty = T;
335   if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
336     // Pass directly in an integer register
337     X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8);
338   } else {
339     // Passed directly on Arm64, but indirectly on X64.
340     X64Ty = PtrTy;
341   }
342 }
343 
344 // This function builds the "exit thunk", a function which translates
345 // arguments and return values when calling x64 code from AArch64 code.
346 Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
347                                                      AttributeList Attrs) {
348   SmallString<256> ExitThunkName;
349   llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
350   FunctionType *Arm64Ty, *X64Ty;
351   getThunkType(FT, Attrs, ThunkType::Exit, ExitThunkStream, Arm64Ty, X64Ty);
352   if (Function *F = M->getFunction(ExitThunkName))
353     return F;
354 
355   Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
356                                  ExitThunkName, M);
357   F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
358   F->setSection(".wowthk$aa");
359   F->setComdat(M->getOrInsertComdat(ExitThunkName));
360   // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
361   F->addFnAttr("frame-pointer", "all");
362   // Only copy sret from the first argument. For C++ instance methods, clang can
363   // stick an sret marking on a later argument, but it doesn't actually affect
364   // the ABI, so we can omit it. This avoids triggering a verifier assertion.
365   if (FT->getNumParams()) {
366     auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
367     auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
368     if (SRet.isValid() && !InReg.isValid())
369       F->addParamAttr(1, SRet);
370   }
371   // FIXME: Copy anything other than sret?  Shouldn't be necessary for normal
372   // C ABI, but might show up in other cases.
373   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
374   IRBuilder<> IRB(BB);
375   Value *CalleePtr =
376       M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
377   Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
378   auto &DL = M->getDataLayout();
379   SmallVector<Value *> Args;
380 
381   // Pass the called function in x9.
382   Args.push_back(F->arg_begin());
383 
384   Type *RetTy = Arm64Ty->getReturnType();
385   if (RetTy != X64Ty->getReturnType()) {
386     // If the return type is an array or struct, translate it. Values of size
387     // 8 or less go into RAX; bigger values go into memory, and we pass a
388     // pointer.
389     if (DL.getTypeStoreSize(RetTy) > 8) {
390       Args.push_back(IRB.CreateAlloca(RetTy));
391     }
392   }
393 
394   for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
395     // Translate arguments from AArch64 calling convention to x86 calling
396     // convention.
397     //
398     // For simple types, we don't need to do any translation: they're
399     // represented the same way. (Implicit sign extension is not part of
400     // either convention.)
401     //
402     // The big thing we have to worry about is struct types... but
403     // fortunately AArch64 clang is pretty friendly here: the cases that need
404     // translation are always passed as a struct or array. (If we run into
405     // some cases where this doesn't work, we can teach clang to mark it up
406     // with an attribute.)
407     //
408     // The first argument is the called function, stored in x9.
409     if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
410         DL.getTypeStoreSize(Arg.getType()) > 8) {
411       Value *Mem = IRB.CreateAlloca(Arg.getType());
412       IRB.CreateStore(&Arg, Mem);
413       if (DL.getTypeStoreSize(Arg.getType()) <= 8) {
414         Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
415         Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
416       } else
417         Args.push_back(Mem);
418     } else {
419       Args.push_back(&Arg);
420     }
421   }
422   // FIXME: Transfer necessary attributes? sret? anything else?
423 
424   Callee = IRB.CreateBitCast(Callee, PtrTy);
425   CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
426   Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
427 
428   Value *RetVal = Call;
429   if (RetTy != X64Ty->getReturnType()) {
430     // If we rewrote the return type earlier, convert the return value to
431     // the proper type.
432     if (DL.getTypeStoreSize(RetTy) > 8) {
433       RetVal = IRB.CreateLoad(RetTy, Args[1]);
434     } else {
435       Value *CastAlloca = IRB.CreateAlloca(RetTy);
436       IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
437       RetVal = IRB.CreateLoad(RetTy, CastAlloca);
438     }
439   }
440 
441   if (RetTy->isVoidTy())
442     IRB.CreateRetVoid();
443   else
444     IRB.CreateRet(RetVal);
445   return F;
446 }
447 
448 // This function builds the "entry thunk", a function which translates
449 // arguments and return values when calling AArch64 code from x64 code.
450 Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
451   SmallString<256> EntryThunkName;
452   llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
453   FunctionType *Arm64Ty, *X64Ty;
454   getThunkType(F->getFunctionType(), F->getAttributes(), ThunkType::Entry,
455                EntryThunkStream, Arm64Ty, X64Ty);
456   if (Function *F = M->getFunction(EntryThunkName))
457     return F;
458 
459   Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
460                                      EntryThunkName, M);
461   Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
462   Thunk->setSection(".wowthk$aa");
463   Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
464   // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
465   Thunk->addFnAttr("frame-pointer", "all");
466 
467   auto &DL = M->getDataLayout();
468   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
469   IRBuilder<> IRB(BB);
470 
471   Type *RetTy = Arm64Ty->getReturnType();
472   Type *X64RetType = X64Ty->getReturnType();
473 
474   bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
475   unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
476   unsigned PassthroughArgSize = F->isVarArg() ? 5 : Thunk->arg_size();
477 
478   // Translate arguments to call.
479   SmallVector<Value *> Args;
480   for (unsigned i = ThunkArgOffset, e = PassthroughArgSize; i != e; ++i) {
481     Value *Arg = Thunk->getArg(i);
482     Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
483     if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
484         DL.getTypeStoreSize(ArgTy) > 8) {
485       // Translate array/struct arguments to the expected type.
486       if (DL.getTypeStoreSize(ArgTy) <= 8) {
487         Value *CastAlloca = IRB.CreateAlloca(ArgTy);
488         IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
489         Arg = IRB.CreateLoad(ArgTy, CastAlloca);
490       } else {
491         Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
492       }
493     }
494     Args.push_back(Arg);
495   }
496 
497   if (F->isVarArg()) {
498     // The 5th argument to variadic entry thunks is used to model the x64 sp
499     // which is passed to the thunk in x4, this can be passed to the callee as
500     // the variadic argument start address after skipping over the 32 byte
501     // shadow store.
502 
503     // The EC thunk CC will assign any argument marked as InReg to x4.
504     Thunk->addParamAttr(5, Attribute::InReg);
505     Value *Arg = Thunk->getArg(5);
506     Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
507     Args.push_back(Arg);
508 
509     // Pass in a zero variadic argument size (in x5).
510     Args.push_back(IRB.getInt64(0));
511   }
512 
513   // Call the function passed to the thunk.
514   Value *Callee = Thunk->getArg(0);
515   Callee = IRB.CreateBitCast(Callee, PtrTy);
516   Value *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
517 
518   Value *RetVal = Call;
519   if (TransformDirectToSRet) {
520     IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
521   } else if (X64RetType != RetTy) {
522     Value *CastAlloca = IRB.CreateAlloca(X64RetType);
523     IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
524     RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
525   }
526 
527   // Return to the caller.  Note that the isel has code to translate this
528   // "ret" to a tail call to __os_arm64x_dispatch_ret.  (Alternatively, we
529   // could emit a tail call here, but that would require a dedicated calling
530   // convention, which seems more complicated overall.)
531   if (X64RetType->isVoidTy())
532     IRB.CreateRetVoid();
533   else
534     IRB.CreateRet(RetVal);
535 
536   return Thunk;
537 }
538 
539 // Builds the "guest exit thunk", a helper to call a function which may or may
540 // not be an exit thunk. (We optimistically assume non-dllimport function
541 // declarations refer to functions defined in AArch64 code; if the linker
542 // can't prove that, we use this routine instead.)
543 Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
544   llvm::raw_null_ostream NullThunkName;
545   FunctionType *Arm64Ty, *X64Ty;
546   getThunkType(F->getFunctionType(), F->getAttributes(), ThunkType::GuestExit,
547                NullThunkName, Arm64Ty, X64Ty);
548   auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
549   assert(MangledName && "Can't guest exit to function that's already native");
550   std::string ThunkName = *MangledName;
551   if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
552     ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
553   } else {
554     ThunkName.append("$exit_thunk");
555   }
556   Function *GuestExit =
557       Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
558   GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
559   GuestExit->setSection(".wowthk$aa");
560   GuestExit->setMetadata(
561       "arm64ec_unmangled_name",
562       MDNode::get(M->getContext(),
563                   MDString::get(M->getContext(), F->getName())));
564   GuestExit->setMetadata(
565       "arm64ec_ecmangled_name",
566       MDNode::get(M->getContext(),
567                   MDString::get(M->getContext(), *MangledName)));
568   F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
569   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
570   IRBuilder<> B(BB);
571 
572   // Load the global symbol as a pointer to the check function.
573   Value *GuardFn;
574   if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
575     GuardFn = GuardFnCFGlobal;
576   else
577     GuardFn = GuardFnGlobal;
578   LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
579 
580   // Create new call instruction. The CFGuard check should always be a call,
581   // even if the original CallBase is an Invoke or CallBr instruction.
582   Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
583   CallInst *GuardCheck = B.CreateCall(
584       GuardFnType, GuardCheckLoad,
585       {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
586 
587   // Ensure that the first argument is passed in the correct register.
588   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
589 
590   Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
591   SmallVector<Value *> Args;
592   for (Argument &Arg : GuestExit->args())
593     Args.push_back(&Arg);
594   CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
595   Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
596 
597   if (Call->getType()->isVoidTy())
598     B.CreateRetVoid();
599   else
600     B.CreateRet(Call);
601 
602   auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
603   auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
604   if (SRetAttr.isValid() && !InRegAttr.isValid()) {
605     GuestExit->addParamAttr(0, SRetAttr);
606     Call->addParamAttr(0, SRetAttr);
607   }
608 
609   return GuestExit;
610 }
611 
612 // Lower an indirect call with inline code.
613 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
614   assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
615          "Only applicable for Windows targets");
616 
617   IRBuilder<> B(CB);
618   Value *CalledOperand = CB->getCalledOperand();
619 
620   // If the indirect call is called within catchpad or cleanuppad,
621   // we need to copy "funclet" bundle of the call.
622   SmallVector<llvm::OperandBundleDef, 1> Bundles;
623   if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
624     Bundles.push_back(OperandBundleDef(*Bundle));
625 
626   // Load the global symbol as a pointer to the check function.
627   Value *GuardFn;
628   if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
629     GuardFn = GuardFnCFGlobal;
630   else
631     GuardFn = GuardFnGlobal;
632   LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
633 
634   // Create new call instruction. The CFGuard check should always be a call,
635   // even if the original CallBase is an Invoke or CallBr instruction.
636   Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
637   CallInst *GuardCheck =
638       B.CreateCall(GuardFnType, GuardCheckLoad,
639                    {B.CreateBitCast(CalledOperand, B.getPtrTy()),
640                     B.CreateBitCast(Thunk, B.getPtrTy())},
641                    Bundles);
642 
643   // Ensure that the first argument is passed in the correct register.
644   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
645 
646   Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
647   CB->setCalledOperand(GuardRetVal);
648 }
649 
650 bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
651   if (!GenerateThunks)
652     return false;
653 
654   M = &Mod;
655 
656   // Check if this module has the cfguard flag and read its value.
657   if (auto *MD =
658           mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
659     cfguard_module_flag = MD->getZExtValue();
660 
661   PtrTy = PointerType::getUnqual(M->getContext());
662   I64Ty = Type::getInt64Ty(M->getContext());
663   VoidTy = Type::getVoidTy(M->getContext());
664 
665   GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
666   GuardFnPtrType = PointerType::get(GuardFnType, 0);
667   GuardFnCFGlobal =
668       M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
669   GuardFnGlobal =
670       M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
671 
672   SetVector<Function *> DirectCalledFns;
673   for (Function &F : Mod)
674     if (!F.isDeclaration() &&
675         F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
676         F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
677       processFunction(F, DirectCalledFns);
678 
679   struct ThunkInfo {
680     Constant *Src;
681     Constant *Dst;
682     unsigned Kind;
683   };
684   SmallVector<ThunkInfo> ThunkMapping;
685   for (Function &F : Mod) {
686     if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
687         F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
688         F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
689       if (!F.hasComdat())
690         F.setComdat(Mod.getOrInsertComdat(F.getName()));
691       ThunkMapping.push_back({&F, buildEntryThunk(&F), 1});
692     }
693   }
694   for (Function *F : DirectCalledFns) {
695     ThunkMapping.push_back(
696         {F, buildExitThunk(F->getFunctionType(), F->getAttributes()), 4});
697     if (!F->hasDLLImportStorageClass())
698       ThunkMapping.push_back({buildGuestExitThunk(F), F, 0});
699   }
700 
701   if (!ThunkMapping.empty()) {
702     SmallVector<Constant *> ThunkMappingArrayElems;
703     for (ThunkInfo &Thunk : ThunkMapping) {
704       ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
705           {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
706            ConstantExpr::getBitCast(Thunk.Dst, PtrTy),
707            ConstantInt::get(M->getContext(), APInt(32, Thunk.Kind))}));
708     }
709     Constant *ThunkMappingArray = ConstantArray::get(
710         llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
711                              ThunkMappingArrayElems.size()),
712         ThunkMappingArrayElems);
713     new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
714                        GlobalValue::ExternalLinkage, ThunkMappingArray,
715                        "llvm.arm64ec.symbolmap");
716   }
717 
718   return true;
719 }
720 
721 bool AArch64Arm64ECCallLowering::processFunction(
722     Function &F, SetVector<Function *> &DirectCalledFns) {
723   SmallVector<CallBase *, 8> IndirectCalls;
724 
725   // For ARM64EC targets, a function definition's name is mangled differently
726   // from the normal symbol. We currently have no representation of this sort
727   // of symbol in IR, so we change the name to the mangled name, then store
728   // the unmangled name as metadata.  Later passes that need the unmangled
729   // name (emitting the definition) can grab it from the metadata.
730   //
731   // FIXME: Handle functions with weak linkage?
732   if (F.hasExternalLinkage() || F.hasWeakLinkage() || F.hasLinkOnceLinkage()) {
733     if (std::optional<std::string> MangledName =
734             getArm64ECMangledFunctionName(F.getName().str())) {
735       F.setMetadata("arm64ec_unmangled_name",
736                     MDNode::get(M->getContext(),
737                                 MDString::get(M->getContext(), F.getName())));
738       if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
739         Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
740         SmallVector<GlobalObject *> ComdatUsers =
741             to_vector(F.getComdat()->getUsers());
742         for (GlobalObject *User : ComdatUsers)
743           User->setComdat(MangledComdat);
744       }
745       F.setName(MangledName.value());
746     }
747   }
748 
749   // Iterate over the instructions to find all indirect call/invoke/callbr
750   // instructions. Make a separate list of pointers to indirect
751   // call/invoke/callbr instructions because the original instructions will be
752   // deleted as the checks are added.
753   for (BasicBlock &BB : F) {
754     for (Instruction &I : BB) {
755       auto *CB = dyn_cast<CallBase>(&I);
756       if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
757           CB->isInlineAsm())
758         continue;
759 
760       // We need to instrument any call that isn't directly calling an
761       // ARM64 function.
762       //
763       // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
764       // unprototyped functions in C)
765       if (Function *F = CB->getCalledFunction()) {
766         if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
767             F->isIntrinsic() || !F->isDeclaration())
768           continue;
769 
770         DirectCalledFns.insert(F);
771         continue;
772       }
773 
774       IndirectCalls.push_back(CB);
775       ++Arm64ECCallsLowered;
776     }
777   }
778 
779   if (IndirectCalls.empty())
780     return false;
781 
782   for (CallBase *CB : IndirectCalls)
783     lowerCall(CB);
784 
785   return true;
786 }
787 
788 char AArch64Arm64ECCallLowering::ID = 0;
789 INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
790                 "AArch64Arm64ECCallLowering", false, false)
791 
792 ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
793   return new AArch64Arm64ECCallLowering;
794 }
795