xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp (revision 3a0793336edfc21cb6d4c8c5c5d7f1665f3e6c5a)
//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file contains the IR transform to lower external or indirect calls for
/// the ARM64EC calling convention. Such calls must go through the runtime, so
/// we can translate the calling convention for calls into the emulator.
///
/// This subsumes Control Flow Guard handling.
///
//===----------------------------------------------------------------------===//
177a6dacacSDimitry Andric 
187a6dacacSDimitry Andric #include "AArch64.h"
197a6dacacSDimitry Andric #include "llvm/ADT/SetVector.h"
207a6dacacSDimitry Andric #include "llvm/ADT/SmallString.h"
217a6dacacSDimitry Andric #include "llvm/ADT/SmallVector.h"
227a6dacacSDimitry Andric #include "llvm/ADT/Statistic.h"
237a6dacacSDimitry Andric #include "llvm/IR/CallingConv.h"
247a6dacacSDimitry Andric #include "llvm/IR/IRBuilder.h"
257a6dacacSDimitry Andric #include "llvm/IR/Instruction.h"
267a6dacacSDimitry Andric #include "llvm/InitializePasses.h"
27439352acSDimitry Andric #include "llvm/Object/COFF.h"
287a6dacacSDimitry Andric #include "llvm/Pass.h"
297a6dacacSDimitry Andric #include "llvm/Support/CommandLine.h"
307a6dacacSDimitry Andric #include "llvm/TargetParser/Triple.h"
317a6dacacSDimitry Andric 
327a6dacacSDimitry Andric using namespace llvm;
33439352acSDimitry Andric using namespace llvm::object;
347a6dacacSDimitry Andric 
357a6dacacSDimitry Andric using OperandBundleDef = OperandBundleDefT<Value *>;
367a6dacacSDimitry Andric 
377a6dacacSDimitry Andric #define DEBUG_TYPE "arm64eccalllowering"
387a6dacacSDimitry Andric 
397a6dacacSDimitry Andric STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
407a6dacacSDimitry Andric 
417a6dacacSDimitry Andric static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
427a6dacacSDimitry Andric                                            cl::Hidden, cl::init(true));
437a6dacacSDimitry Andric static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
447a6dacacSDimitry Andric                                     cl::init(true));
457a6dacacSDimitry Andric 
467a6dacacSDimitry Andric namespace {
477a6dacacSDimitry Andric 
48439352acSDimitry Andric enum class ThunkType { GuestExit, Entry, Exit };
49439352acSDimitry Andric 
507a6dacacSDimitry Andric class AArch64Arm64ECCallLowering : public ModulePass {
517a6dacacSDimitry Andric public:
527a6dacacSDimitry Andric   static char ID;
537a6dacacSDimitry Andric   AArch64Arm64ECCallLowering() : ModulePass(ID) {
547a6dacacSDimitry Andric     initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
557a6dacacSDimitry Andric   }
567a6dacacSDimitry Andric 
577a6dacacSDimitry Andric   Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
587a6dacacSDimitry Andric   Function *buildEntryThunk(Function *F);
597a6dacacSDimitry Andric   void lowerCall(CallBase *CB);
607a6dacacSDimitry Andric   Function *buildGuestExitThunk(Function *F);
617a6dacacSDimitry Andric   bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
627a6dacacSDimitry Andric   bool runOnModule(Module &M) override;
637a6dacacSDimitry Andric 
647a6dacacSDimitry Andric private:
657a6dacacSDimitry Andric   int cfguard_module_flag = 0;
667a6dacacSDimitry Andric   FunctionType *GuardFnType = nullptr;
677a6dacacSDimitry Andric   PointerType *GuardFnPtrType = nullptr;
687a6dacacSDimitry Andric   Constant *GuardFnCFGlobal = nullptr;
697a6dacacSDimitry Andric   Constant *GuardFnGlobal = nullptr;
707a6dacacSDimitry Andric   Module *M = nullptr;
717a6dacacSDimitry Andric 
727a6dacacSDimitry Andric   Type *PtrTy;
737a6dacacSDimitry Andric   Type *I64Ty;
747a6dacacSDimitry Andric   Type *VoidTy;
757a6dacacSDimitry Andric 
76439352acSDimitry Andric   void getThunkType(FunctionType *FT, AttributeList AttrList, ThunkType TT,
777a6dacacSDimitry Andric                     raw_ostream &Out, FunctionType *&Arm64Ty,
787a6dacacSDimitry Andric                     FunctionType *&X64Ty);
797a6dacacSDimitry Andric   void getThunkRetType(FunctionType *FT, AttributeList AttrList,
807a6dacacSDimitry Andric                        raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
817a6dacacSDimitry Andric                        SmallVectorImpl<Type *> &Arm64ArgTypes,
827a6dacacSDimitry Andric                        SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
83439352acSDimitry Andric   void getThunkArgTypes(FunctionType *FT, AttributeList AttrList, ThunkType TT,
847a6dacacSDimitry Andric                         raw_ostream &Out,
857a6dacacSDimitry Andric                         SmallVectorImpl<Type *> &Arm64ArgTypes,
867a6dacacSDimitry Andric                         SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
877a6dacacSDimitry Andric   void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
887a6dacacSDimitry Andric                              uint64_t ArgSizeBytes, raw_ostream &Out,
897a6dacacSDimitry Andric                              Type *&Arm64Ty, Type *&X64Ty);
907a6dacacSDimitry Andric };
917a6dacacSDimitry Andric 
927a6dacacSDimitry Andric } // end anonymous namespace
937a6dacacSDimitry Andric 
947a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::getThunkType(FunctionType *FT,
957a6dacacSDimitry Andric                                               AttributeList AttrList,
96439352acSDimitry Andric                                               ThunkType TT, raw_ostream &Out,
977a6dacacSDimitry Andric                                               FunctionType *&Arm64Ty,
987a6dacacSDimitry Andric                                               FunctionType *&X64Ty) {
99439352acSDimitry Andric   Out << (TT == ThunkType::Entry ? "$ientry_thunk$cdecl$"
100439352acSDimitry Andric                                  : "$iexit_thunk$cdecl$");
1017a6dacacSDimitry Andric 
1027a6dacacSDimitry Andric   Type *Arm64RetTy;
1037a6dacacSDimitry Andric   Type *X64RetTy;
1047a6dacacSDimitry Andric 
1057a6dacacSDimitry Andric   SmallVector<Type *> Arm64ArgTypes;
1067a6dacacSDimitry Andric   SmallVector<Type *> X64ArgTypes;
1077a6dacacSDimitry Andric 
1087a6dacacSDimitry Andric   // The first argument to a thunk is the called function, stored in x9.
1097a6dacacSDimitry Andric   // For exit thunks, we pass the called function down to the emulator;
110439352acSDimitry Andric   // for entry/guest exit thunks, we just call the Arm64 function directly.
111439352acSDimitry Andric   if (TT == ThunkType::Exit)
1127a6dacacSDimitry Andric     Arm64ArgTypes.push_back(PtrTy);
1137a6dacacSDimitry Andric   X64ArgTypes.push_back(PtrTy);
1147a6dacacSDimitry Andric 
1157a6dacacSDimitry Andric   bool HasSretPtr = false;
1167a6dacacSDimitry Andric   getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
1177a6dacacSDimitry Andric                   X64ArgTypes, HasSretPtr);
1187a6dacacSDimitry Andric 
119439352acSDimitry Andric   getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
120439352acSDimitry Andric                    HasSretPtr);
1217a6dacacSDimitry Andric 
1227a6dacacSDimitry Andric   Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
123439352acSDimitry Andric 
1247a6dacacSDimitry Andric   X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
1257a6dacacSDimitry Andric }
1267a6dacacSDimitry Andric 
1277a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::getThunkArgTypes(
128439352acSDimitry Andric     FunctionType *FT, AttributeList AttrList, ThunkType TT, raw_ostream &Out,
1297a6dacacSDimitry Andric     SmallVectorImpl<Type *> &Arm64ArgTypes,
1307a6dacacSDimitry Andric     SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
1317a6dacacSDimitry Andric 
1327a6dacacSDimitry Andric   Out << "$";
1337a6dacacSDimitry Andric   if (FT->isVarArg()) {
1347a6dacacSDimitry Andric     // We treat the variadic function's thunk as a normal function
1357a6dacacSDimitry Andric     // with the following type on the ARM side:
1367a6dacacSDimitry Andric     //   rettype exitthunk(
1377a6dacacSDimitry Andric     //     ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
1387a6dacacSDimitry Andric     //
1397a6dacacSDimitry Andric     // that can coverage all types of variadic function.
1407a6dacacSDimitry Andric     // x9 is similar to normal exit thunk, store the called function.
1417a6dacacSDimitry Andric     // x0-x3 is the arguments be stored in registers.
1427a6dacacSDimitry Andric     // x4 is the address of the arguments on the stack.
1437a6dacacSDimitry Andric     // x5 is the size of the arguments on the stack.
1447a6dacacSDimitry Andric     //
1457a6dacacSDimitry Andric     // On the x64 side, it's the same except that x5 isn't set.
1467a6dacacSDimitry Andric     //
1477a6dacacSDimitry Andric     // If both the ARM and X64 sides are sret, there are only three
1487a6dacacSDimitry Andric     // arguments in registers.
1497a6dacacSDimitry Andric     //
1507a6dacacSDimitry Andric     // If the X64 side is sret, but the ARM side isn't, we pass an extra value
1517a6dacacSDimitry Andric     // to/from the X64 side, and let SelectionDAG transform it into a memory
1527a6dacacSDimitry Andric     // location.
1537a6dacacSDimitry Andric     Out << "varargs";
1547a6dacacSDimitry Andric 
1557a6dacacSDimitry Andric     // x0-x3
1567a6dacacSDimitry Andric     for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
1577a6dacacSDimitry Andric       Arm64ArgTypes.push_back(I64Ty);
1587a6dacacSDimitry Andric       X64ArgTypes.push_back(I64Ty);
1597a6dacacSDimitry Andric     }
1607a6dacacSDimitry Andric 
1617a6dacacSDimitry Andric     // x4
1627a6dacacSDimitry Andric     Arm64ArgTypes.push_back(PtrTy);
1637a6dacacSDimitry Andric     X64ArgTypes.push_back(PtrTy);
1647a6dacacSDimitry Andric     // x5
1657a6dacacSDimitry Andric     Arm64ArgTypes.push_back(I64Ty);
166439352acSDimitry Andric     if (TT != ThunkType::Entry) {
167439352acSDimitry Andric       // FIXME: x5 isn't actually used by the x64 side; revisit once we
1687a6dacacSDimitry Andric       // have proper isel for varargs
1697a6dacacSDimitry Andric       X64ArgTypes.push_back(I64Ty);
170439352acSDimitry Andric     }
1717a6dacacSDimitry Andric     return;
1727a6dacacSDimitry Andric   }
1737a6dacacSDimitry Andric 
1747a6dacacSDimitry Andric   unsigned I = 0;
1757a6dacacSDimitry Andric   if (HasSretPtr)
1767a6dacacSDimitry Andric     I++;
1777a6dacacSDimitry Andric 
1787a6dacacSDimitry Andric   if (I == FT->getNumParams()) {
1797a6dacacSDimitry Andric     Out << "v";
1807a6dacacSDimitry Andric     return;
1817a6dacacSDimitry Andric   }
1827a6dacacSDimitry Andric 
1837a6dacacSDimitry Andric   for (unsigned E = FT->getNumParams(); I != E; ++I) {
1847a6dacacSDimitry Andric #if 0
1857a6dacacSDimitry Andric     // FIXME: Need more information about argument size; see
1867a6dacacSDimitry Andric     // https://reviews.llvm.org/D132926
1877a6dacacSDimitry Andric     uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
188*3a079333SDimitry Andric     Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
1897a6dacacSDimitry Andric #else
1907a6dacacSDimitry Andric     uint64_t ArgSizeBytes = 0;
191*3a079333SDimitry Andric     Align ParamAlign = Align();
1927a6dacacSDimitry Andric #endif
1937a6dacacSDimitry Andric     Type *Arm64Ty, *X64Ty;
1947a6dacacSDimitry Andric     canonicalizeThunkType(FT->getParamType(I), ParamAlign,
1957a6dacacSDimitry Andric                           /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
1967a6dacacSDimitry Andric     Arm64ArgTypes.push_back(Arm64Ty);
1977a6dacacSDimitry Andric     X64ArgTypes.push_back(X64Ty);
1987a6dacacSDimitry Andric   }
1997a6dacacSDimitry Andric }
2007a6dacacSDimitry Andric 
2017a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::getThunkRetType(
2027a6dacacSDimitry Andric     FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
2037a6dacacSDimitry Andric     Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
2047a6dacacSDimitry Andric     SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
2057a6dacacSDimitry Andric   Type *T = FT->getReturnType();
2067a6dacacSDimitry Andric #if 0
2077a6dacacSDimitry Andric   // FIXME: Need more information about argument size; see
2087a6dacacSDimitry Andric   // https://reviews.llvm.org/D132926
2097a6dacacSDimitry Andric   uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
2107a6dacacSDimitry Andric #else
2117a6dacacSDimitry Andric   int64_t ArgSizeBytes = 0;
2127a6dacacSDimitry Andric #endif
2137a6dacacSDimitry Andric   if (T->isVoidTy()) {
2147a6dacacSDimitry Andric     if (FT->getNumParams()) {
2157a6dacacSDimitry Andric       auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet);
2167a6dacacSDimitry Andric       auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg);
2177a6dacacSDimitry Andric       if (SRetAttr.isValid() && InRegAttr.isValid()) {
2187a6dacacSDimitry Andric         // sret+inreg indicates a call that returns a C++ class value. This is
2197a6dacacSDimitry Andric         // actually equivalent to just passing and returning a void* pointer
2207a6dacacSDimitry Andric         // as the first argument. Translate it that way, instead of trying
2217a6dacacSDimitry Andric         // to model "inreg" in the thunk's calling convention, to simplify
2227a6dacacSDimitry Andric         // the rest of the code.
2237a6dacacSDimitry Andric         Out << "i8";
2247a6dacacSDimitry Andric         Arm64RetTy = I64Ty;
2257a6dacacSDimitry Andric         X64RetTy = I64Ty;
2267a6dacacSDimitry Andric         return;
2277a6dacacSDimitry Andric       }
2287a6dacacSDimitry Andric       if (SRetAttr.isValid()) {
2297a6dacacSDimitry Andric         // FIXME: Sanity-check the sret type; if it's an integer or pointer,
2307a6dacacSDimitry Andric         // we'll get screwy mangling/codegen.
2317a6dacacSDimitry Andric         // FIXME: For large struct types, mangle as an integer argument and
2327a6dacacSDimitry Andric         // integer return, so we can reuse more thunks, instead of "m" syntax.
2337a6dacacSDimitry Andric         // (MSVC mangles this case as an integer return with no argument, but
2347a6dacacSDimitry Andric         // that's a miscompile.)
2357a6dacacSDimitry Andric         Type *SRetType = SRetAttr.getValueAsType();
2367a6dacacSDimitry Andric         Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
2377a6dacacSDimitry Andric         Type *Arm64Ty, *X64Ty;
2387a6dacacSDimitry Andric         canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
2397a6dacacSDimitry Andric                               Out, Arm64Ty, X64Ty);
2407a6dacacSDimitry Andric         Arm64RetTy = VoidTy;
2417a6dacacSDimitry Andric         X64RetTy = VoidTy;
2427a6dacacSDimitry Andric         Arm64ArgTypes.push_back(FT->getParamType(0));
2437a6dacacSDimitry Andric         X64ArgTypes.push_back(FT->getParamType(0));
2447a6dacacSDimitry Andric         HasSretPtr = true;
2457a6dacacSDimitry Andric         return;
2467a6dacacSDimitry Andric       }
2477a6dacacSDimitry Andric     }
2487a6dacacSDimitry Andric 
2497a6dacacSDimitry Andric     Out << "v";
2507a6dacacSDimitry Andric     Arm64RetTy = VoidTy;
2517a6dacacSDimitry Andric     X64RetTy = VoidTy;
2527a6dacacSDimitry Andric     return;
2537a6dacacSDimitry Andric   }
2547a6dacacSDimitry Andric 
2557a6dacacSDimitry Andric   canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64RetTy,
2567a6dacacSDimitry Andric                         X64RetTy);
2577a6dacacSDimitry Andric   if (X64RetTy->isPointerTy()) {
2587a6dacacSDimitry Andric     // If the X64 type is canonicalized to a pointer, that means it's
2597a6dacacSDimitry Andric     // passed/returned indirectly. For a return value, that means it's an
2607a6dacacSDimitry Andric     // sret pointer.
2617a6dacacSDimitry Andric     X64ArgTypes.push_back(X64RetTy);
2627a6dacacSDimitry Andric     X64RetTy = VoidTy;
2637a6dacacSDimitry Andric   }
2647a6dacacSDimitry Andric }
2657a6dacacSDimitry Andric 
2667a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::canonicalizeThunkType(
2677a6dacacSDimitry Andric     Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
2687a6dacacSDimitry Andric     Type *&Arm64Ty, Type *&X64Ty) {
2697a6dacacSDimitry Andric   if (T->isFloatTy()) {
2707a6dacacSDimitry Andric     Out << "f";
2717a6dacacSDimitry Andric     Arm64Ty = T;
2727a6dacacSDimitry Andric     X64Ty = T;
2737a6dacacSDimitry Andric     return;
2747a6dacacSDimitry Andric   }
2757a6dacacSDimitry Andric 
2767a6dacacSDimitry Andric   if (T->isDoubleTy()) {
2777a6dacacSDimitry Andric     Out << "d";
2787a6dacacSDimitry Andric     Arm64Ty = T;
2797a6dacacSDimitry Andric     X64Ty = T;
2807a6dacacSDimitry Andric     return;
2817a6dacacSDimitry Andric   }
2827a6dacacSDimitry Andric 
2837a6dacacSDimitry Andric   if (T->isFloatingPointTy()) {
2847a6dacacSDimitry Andric     report_fatal_error(
2857a6dacacSDimitry Andric         "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
2867a6dacacSDimitry Andric   }
2877a6dacacSDimitry Andric 
2887a6dacacSDimitry Andric   auto &DL = M->getDataLayout();
2897a6dacacSDimitry Andric 
2907a6dacacSDimitry Andric   if (auto *StructTy = dyn_cast<StructType>(T))
2917a6dacacSDimitry Andric     if (StructTy->getNumElements() == 1)
2927a6dacacSDimitry Andric       T = StructTy->getElementType(0);
2937a6dacacSDimitry Andric 
2947a6dacacSDimitry Andric   if (T->isArrayTy()) {
2957a6dacacSDimitry Andric     Type *ElementTy = T->getArrayElementType();
2967a6dacacSDimitry Andric     uint64_t ElementCnt = T->getArrayNumElements();
2977a6dacacSDimitry Andric     uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
2987a6dacacSDimitry Andric     uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
2997a6dacacSDimitry Andric     if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
3007a6dacacSDimitry Andric       Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
301*3a079333SDimitry Andric       if (Alignment.value() >= 16 && !Ret)
3027a6dacacSDimitry Andric         Out << "a" << Alignment.value();
3037a6dacacSDimitry Andric       Arm64Ty = T;
3047a6dacacSDimitry Andric       if (TotalSizeBytes <= 8) {
3057a6dacacSDimitry Andric         // Arm64 returns small structs of float/double in float registers;
3067a6dacacSDimitry Andric         // X64 uses RAX.
3077a6dacacSDimitry Andric         X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8);
3087a6dacacSDimitry Andric       } else {
3097a6dacacSDimitry Andric         // Struct is passed directly on Arm64, but indirectly on X64.
3107a6dacacSDimitry Andric         X64Ty = PtrTy;
3117a6dacacSDimitry Andric       }
3127a6dacacSDimitry Andric       return;
3137a6dacacSDimitry Andric     } else if (T->isFloatingPointTy()) {
3147a6dacacSDimitry Andric       report_fatal_error("Only 32 and 64 bit floating points are supported for "
3157a6dacacSDimitry Andric                          "ARM64EC thunks");
3167a6dacacSDimitry Andric     }
3177a6dacacSDimitry Andric   }
3187a6dacacSDimitry Andric 
3197a6dacacSDimitry Andric   if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
3207a6dacacSDimitry Andric     Out << "i8";
3217a6dacacSDimitry Andric     Arm64Ty = I64Ty;
3227a6dacacSDimitry Andric     X64Ty = I64Ty;
3237a6dacacSDimitry Andric     return;
3247a6dacacSDimitry Andric   }
3257a6dacacSDimitry Andric 
3267a6dacacSDimitry Andric   unsigned TypeSize = ArgSizeBytes;
3277a6dacacSDimitry Andric   if (TypeSize == 0)
3287a6dacacSDimitry Andric     TypeSize = DL.getTypeSizeInBits(T) / 8;
3297a6dacacSDimitry Andric   Out << "m";
3307a6dacacSDimitry Andric   if (TypeSize != 4)
3317a6dacacSDimitry Andric     Out << TypeSize;
332*3a079333SDimitry Andric   if (Alignment.value() >= 16 && !Ret)
3337a6dacacSDimitry Andric     Out << "a" << Alignment.value();
3347a6dacacSDimitry Andric   // FIXME: Try to canonicalize Arm64Ty more thoroughly?
3357a6dacacSDimitry Andric   Arm64Ty = T;
3367a6dacacSDimitry Andric   if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
3377a6dacacSDimitry Andric     // Pass directly in an integer register
3387a6dacacSDimitry Andric     X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8);
3397a6dacacSDimitry Andric   } else {
3407a6dacacSDimitry Andric     // Passed directly on Arm64, but indirectly on X64.
3417a6dacacSDimitry Andric     X64Ty = PtrTy;
3427a6dacacSDimitry Andric   }
3437a6dacacSDimitry Andric }
3447a6dacacSDimitry Andric 
3457a6dacacSDimitry Andric // This function builds the "exit thunk", a function which translates
3467a6dacacSDimitry Andric // arguments and return values when calling x64 code from AArch64 code.
3477a6dacacSDimitry Andric Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
3487a6dacacSDimitry Andric                                                      AttributeList Attrs) {
3497a6dacacSDimitry Andric   SmallString<256> ExitThunkName;
3507a6dacacSDimitry Andric   llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
3517a6dacacSDimitry Andric   FunctionType *Arm64Ty, *X64Ty;
352439352acSDimitry Andric   getThunkType(FT, Attrs, ThunkType::Exit, ExitThunkStream, Arm64Ty, X64Ty);
3537a6dacacSDimitry Andric   if (Function *F = M->getFunction(ExitThunkName))
3547a6dacacSDimitry Andric     return F;
3557a6dacacSDimitry Andric 
3567a6dacacSDimitry Andric   Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
3577a6dacacSDimitry Andric                                  ExitThunkName, M);
3587a6dacacSDimitry Andric   F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
3597a6dacacSDimitry Andric   F->setSection(".wowthk$aa");
3607a6dacacSDimitry Andric   F->setComdat(M->getOrInsertComdat(ExitThunkName));
3617a6dacacSDimitry Andric   // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
3627a6dacacSDimitry Andric   F->addFnAttr("frame-pointer", "all");
3637a6dacacSDimitry Andric   // Only copy sret from the first argument. For C++ instance methods, clang can
3647a6dacacSDimitry Andric   // stick an sret marking on a later argument, but it doesn't actually affect
3657a6dacacSDimitry Andric   // the ABI, so we can omit it. This avoids triggering a verifier assertion.
3667a6dacacSDimitry Andric   if (FT->getNumParams()) {
3677a6dacacSDimitry Andric     auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
3687a6dacacSDimitry Andric     auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
3697a6dacacSDimitry Andric     if (SRet.isValid() && !InReg.isValid())
3707a6dacacSDimitry Andric       F->addParamAttr(1, SRet);
3717a6dacacSDimitry Andric   }
3727a6dacacSDimitry Andric   // FIXME: Copy anything other than sret?  Shouldn't be necessary for normal
3737a6dacacSDimitry Andric   // C ABI, but might show up in other cases.
3747a6dacacSDimitry Andric   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
3757a6dacacSDimitry Andric   IRBuilder<> IRB(BB);
3767a6dacacSDimitry Andric   Value *CalleePtr =
3777a6dacacSDimitry Andric       M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
3787a6dacacSDimitry Andric   Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
3797a6dacacSDimitry Andric   auto &DL = M->getDataLayout();
3807a6dacacSDimitry Andric   SmallVector<Value *> Args;
3817a6dacacSDimitry Andric 
3827a6dacacSDimitry Andric   // Pass the called function in x9.
3837a6dacacSDimitry Andric   Args.push_back(F->arg_begin());
3847a6dacacSDimitry Andric 
3857a6dacacSDimitry Andric   Type *RetTy = Arm64Ty->getReturnType();
3867a6dacacSDimitry Andric   if (RetTy != X64Ty->getReturnType()) {
3877a6dacacSDimitry Andric     // If the return type is an array or struct, translate it. Values of size
3887a6dacacSDimitry Andric     // 8 or less go into RAX; bigger values go into memory, and we pass a
3897a6dacacSDimitry Andric     // pointer.
3907a6dacacSDimitry Andric     if (DL.getTypeStoreSize(RetTy) > 8) {
3917a6dacacSDimitry Andric       Args.push_back(IRB.CreateAlloca(RetTy));
3927a6dacacSDimitry Andric     }
3937a6dacacSDimitry Andric   }
3947a6dacacSDimitry Andric 
3957a6dacacSDimitry Andric   for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
3967a6dacacSDimitry Andric     // Translate arguments from AArch64 calling convention to x86 calling
3977a6dacacSDimitry Andric     // convention.
3987a6dacacSDimitry Andric     //
3997a6dacacSDimitry Andric     // For simple types, we don't need to do any translation: they're
4007a6dacacSDimitry Andric     // represented the same way. (Implicit sign extension is not part of
4017a6dacacSDimitry Andric     // either convention.)
4027a6dacacSDimitry Andric     //
4037a6dacacSDimitry Andric     // The big thing we have to worry about is struct types... but
4047a6dacacSDimitry Andric     // fortunately AArch64 clang is pretty friendly here: the cases that need
4057a6dacacSDimitry Andric     // translation are always passed as a struct or array. (If we run into
4067a6dacacSDimitry Andric     // some cases where this doesn't work, we can teach clang to mark it up
4077a6dacacSDimitry Andric     // with an attribute.)
4087a6dacacSDimitry Andric     //
4097a6dacacSDimitry Andric     // The first argument is the called function, stored in x9.
4107a6dacacSDimitry Andric     if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
4117a6dacacSDimitry Andric         DL.getTypeStoreSize(Arg.getType()) > 8) {
4127a6dacacSDimitry Andric       Value *Mem = IRB.CreateAlloca(Arg.getType());
4137a6dacacSDimitry Andric       IRB.CreateStore(&Arg, Mem);
4147a6dacacSDimitry Andric       if (DL.getTypeStoreSize(Arg.getType()) <= 8) {
4157a6dacacSDimitry Andric         Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
4167a6dacacSDimitry Andric         Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
4177a6dacacSDimitry Andric       } else
4187a6dacacSDimitry Andric         Args.push_back(Mem);
4197a6dacacSDimitry Andric     } else {
4207a6dacacSDimitry Andric       Args.push_back(&Arg);
4217a6dacacSDimitry Andric     }
4227a6dacacSDimitry Andric   }
4237a6dacacSDimitry Andric   // FIXME: Transfer necessary attributes? sret? anything else?
4247a6dacacSDimitry Andric 
4257a6dacacSDimitry Andric   Callee = IRB.CreateBitCast(Callee, PtrTy);
4267a6dacacSDimitry Andric   CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
4277a6dacacSDimitry Andric   Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
4287a6dacacSDimitry Andric 
4297a6dacacSDimitry Andric   Value *RetVal = Call;
4307a6dacacSDimitry Andric   if (RetTy != X64Ty->getReturnType()) {
4317a6dacacSDimitry Andric     // If we rewrote the return type earlier, convert the return value to
4327a6dacacSDimitry Andric     // the proper type.
4337a6dacacSDimitry Andric     if (DL.getTypeStoreSize(RetTy) > 8) {
4347a6dacacSDimitry Andric       RetVal = IRB.CreateLoad(RetTy, Args[1]);
4357a6dacacSDimitry Andric     } else {
4367a6dacacSDimitry Andric       Value *CastAlloca = IRB.CreateAlloca(RetTy);
4377a6dacacSDimitry Andric       IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
4387a6dacacSDimitry Andric       RetVal = IRB.CreateLoad(RetTy, CastAlloca);
4397a6dacacSDimitry Andric     }
4407a6dacacSDimitry Andric   }
4417a6dacacSDimitry Andric 
4427a6dacacSDimitry Andric   if (RetTy->isVoidTy())
4437a6dacacSDimitry Andric     IRB.CreateRetVoid();
4447a6dacacSDimitry Andric   else
4457a6dacacSDimitry Andric     IRB.CreateRet(RetVal);
4467a6dacacSDimitry Andric   return F;
4477a6dacacSDimitry Andric }
4487a6dacacSDimitry Andric 
4497a6dacacSDimitry Andric // This function builds the "entry thunk", a function which translates
4507a6dacacSDimitry Andric // arguments and return values when calling AArch64 code from x64 code.
4517a6dacacSDimitry Andric Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
4527a6dacacSDimitry Andric   SmallString<256> EntryThunkName;
4537a6dacacSDimitry Andric   llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
4547a6dacacSDimitry Andric   FunctionType *Arm64Ty, *X64Ty;
455439352acSDimitry Andric   getThunkType(F->getFunctionType(), F->getAttributes(), ThunkType::Entry,
4567a6dacacSDimitry Andric                EntryThunkStream, Arm64Ty, X64Ty);
4577a6dacacSDimitry Andric   if (Function *F = M->getFunction(EntryThunkName))
4587a6dacacSDimitry Andric     return F;
4597a6dacacSDimitry Andric 
4607a6dacacSDimitry Andric   Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
4617a6dacacSDimitry Andric                                      EntryThunkName, M);
4627a6dacacSDimitry Andric   Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
4637a6dacacSDimitry Andric   Thunk->setSection(".wowthk$aa");
4647a6dacacSDimitry Andric   Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
4657a6dacacSDimitry Andric   // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
4667a6dacacSDimitry Andric   Thunk->addFnAttr("frame-pointer", "all");
4677a6dacacSDimitry Andric 
4687a6dacacSDimitry Andric   auto &DL = M->getDataLayout();
4697a6dacacSDimitry Andric   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
4707a6dacacSDimitry Andric   IRBuilder<> IRB(BB);
4717a6dacacSDimitry Andric 
4727a6dacacSDimitry Andric   Type *RetTy = Arm64Ty->getReturnType();
4737a6dacacSDimitry Andric   Type *X64RetType = X64Ty->getReturnType();
4747a6dacacSDimitry Andric 
4757a6dacacSDimitry Andric   bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
4767a6dacacSDimitry Andric   unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
477439352acSDimitry Andric   unsigned PassthroughArgSize = F->isVarArg() ? 5 : Thunk->arg_size();
4787a6dacacSDimitry Andric 
4797a6dacacSDimitry Andric   // Translate arguments to call.
4807a6dacacSDimitry Andric   SmallVector<Value *> Args;
481439352acSDimitry Andric   for (unsigned i = ThunkArgOffset, e = PassthroughArgSize; i != e; ++i) {
4827a6dacacSDimitry Andric     Value *Arg = Thunk->getArg(i);
4837a6dacacSDimitry Andric     Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
4847a6dacacSDimitry Andric     if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
4857a6dacacSDimitry Andric         DL.getTypeStoreSize(ArgTy) > 8) {
4867a6dacacSDimitry Andric       // Translate array/struct arguments to the expected type.
4877a6dacacSDimitry Andric       if (DL.getTypeStoreSize(ArgTy) <= 8) {
4887a6dacacSDimitry Andric         Value *CastAlloca = IRB.CreateAlloca(ArgTy);
4897a6dacacSDimitry Andric         IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
4907a6dacacSDimitry Andric         Arg = IRB.CreateLoad(ArgTy, CastAlloca);
4917a6dacacSDimitry Andric       } else {
4927a6dacacSDimitry Andric         Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
4937a6dacacSDimitry Andric       }
4947a6dacacSDimitry Andric     }
4957a6dacacSDimitry Andric     Args.push_back(Arg);
4967a6dacacSDimitry Andric   }
4977a6dacacSDimitry Andric 
498439352acSDimitry Andric   if (F->isVarArg()) {
499439352acSDimitry Andric     // The 5th argument to variadic entry thunks is used to model the x64 sp
500439352acSDimitry Andric     // which is passed to the thunk in x4, this can be passed to the callee as
501439352acSDimitry Andric     // the variadic argument start address after skipping over the 32 byte
502439352acSDimitry Andric     // shadow store.
503439352acSDimitry Andric 
504439352acSDimitry Andric     // The EC thunk CC will assign any argument marked as InReg to x4.
505439352acSDimitry Andric     Thunk->addParamAttr(5, Attribute::InReg);
506439352acSDimitry Andric     Value *Arg = Thunk->getArg(5);
507439352acSDimitry Andric     Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
508439352acSDimitry Andric     Args.push_back(Arg);
509439352acSDimitry Andric 
510439352acSDimitry Andric     // Pass in a zero variadic argument size (in x5).
511439352acSDimitry Andric     Args.push_back(IRB.getInt64(0));
512439352acSDimitry Andric   }
513439352acSDimitry Andric 
5147a6dacacSDimitry Andric   // Call the function passed to the thunk.
5157a6dacacSDimitry Andric   Value *Callee = Thunk->getArg(0);
5167a6dacacSDimitry Andric   Callee = IRB.CreateBitCast(Callee, PtrTy);
517*3a079333SDimitry Andric   CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
518*3a079333SDimitry Andric 
519*3a079333SDimitry Andric   auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
520*3a079333SDimitry Andric   auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
521*3a079333SDimitry Andric   if (SRetAttr.isValid() && !InRegAttr.isValid()) {
522*3a079333SDimitry Andric     Thunk->addParamAttr(1, SRetAttr);
523*3a079333SDimitry Andric     Call->addParamAttr(0, SRetAttr);
524*3a079333SDimitry Andric   }
5257a6dacacSDimitry Andric 
5267a6dacacSDimitry Andric   Value *RetVal = Call;
5277a6dacacSDimitry Andric   if (TransformDirectToSRet) {
5287a6dacacSDimitry Andric     IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
5297a6dacacSDimitry Andric   } else if (X64RetType != RetTy) {
5307a6dacacSDimitry Andric     Value *CastAlloca = IRB.CreateAlloca(X64RetType);
5317a6dacacSDimitry Andric     IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
5327a6dacacSDimitry Andric     RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
5337a6dacacSDimitry Andric   }
5347a6dacacSDimitry Andric 
5357a6dacacSDimitry Andric   // Return to the caller.  Note that the isel has code to translate this
5367a6dacacSDimitry Andric   // "ret" to a tail call to __os_arm64x_dispatch_ret.  (Alternatively, we
5377a6dacacSDimitry Andric   // could emit a tail call here, but that would require a dedicated calling
5387a6dacacSDimitry Andric   // convention, which seems more complicated overall.)
5397a6dacacSDimitry Andric   if (X64RetType->isVoidTy())
5407a6dacacSDimitry Andric     IRB.CreateRetVoid();
5417a6dacacSDimitry Andric   else
5427a6dacacSDimitry Andric     IRB.CreateRet(RetVal);
5437a6dacacSDimitry Andric 
5447a6dacacSDimitry Andric   return Thunk;
5457a6dacacSDimitry Andric }
5467a6dacacSDimitry Andric 
5477a6dacacSDimitry Andric // Builds the "guest exit thunk", a helper to call a function which may or may
5487a6dacacSDimitry Andric // not be an exit thunk. (We optimistically assume non-dllimport function
5497a6dacacSDimitry Andric // declarations refer to functions defined in AArch64 code; if the linker
5507a6dacacSDimitry Andric // can't prove that, we use this routine instead.)
5517a6dacacSDimitry Andric Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
5527a6dacacSDimitry Andric   llvm::raw_null_ostream NullThunkName;
5537a6dacacSDimitry Andric   FunctionType *Arm64Ty, *X64Ty;
554439352acSDimitry Andric   getThunkType(F->getFunctionType(), F->getAttributes(), ThunkType::GuestExit,
5557a6dacacSDimitry Andric                NullThunkName, Arm64Ty, X64Ty);
5567a6dacacSDimitry Andric   auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
5577a6dacacSDimitry Andric   assert(MangledName && "Can't guest exit to function that's already native");
5587a6dacacSDimitry Andric   std::string ThunkName = *MangledName;
5597a6dacacSDimitry Andric   if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
5607a6dacacSDimitry Andric     ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
5617a6dacacSDimitry Andric   } else {
5627a6dacacSDimitry Andric     ThunkName.append("$exit_thunk");
5637a6dacacSDimitry Andric   }
5647a6dacacSDimitry Andric   Function *GuestExit =
5657a6dacacSDimitry Andric       Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
5667a6dacacSDimitry Andric   GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
5677a6dacacSDimitry Andric   GuestExit->setSection(".wowthk$aa");
5687a6dacacSDimitry Andric   GuestExit->setMetadata(
5697a6dacacSDimitry Andric       "arm64ec_unmangled_name",
5707a6dacacSDimitry Andric       MDNode::get(M->getContext(),
5717a6dacacSDimitry Andric                   MDString::get(M->getContext(), F->getName())));
5727a6dacacSDimitry Andric   GuestExit->setMetadata(
5737a6dacacSDimitry Andric       "arm64ec_ecmangled_name",
5747a6dacacSDimitry Andric       MDNode::get(M->getContext(),
5757a6dacacSDimitry Andric                   MDString::get(M->getContext(), *MangledName)));
5767a6dacacSDimitry Andric   F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
5777a6dacacSDimitry Andric   BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
5787a6dacacSDimitry Andric   IRBuilder<> B(BB);
5797a6dacacSDimitry Andric 
5807a6dacacSDimitry Andric   // Load the global symbol as a pointer to the check function.
5817a6dacacSDimitry Andric   Value *GuardFn;
5827a6dacacSDimitry Andric   if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
5837a6dacacSDimitry Andric     GuardFn = GuardFnCFGlobal;
5847a6dacacSDimitry Andric   else
5857a6dacacSDimitry Andric     GuardFn = GuardFnGlobal;
5867a6dacacSDimitry Andric   LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
5877a6dacacSDimitry Andric 
5887a6dacacSDimitry Andric   // Create new call instruction. The CFGuard check should always be a call,
5897a6dacacSDimitry Andric   // even if the original CallBase is an Invoke or CallBr instruction.
5907a6dacacSDimitry Andric   Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
5917a6dacacSDimitry Andric   CallInst *GuardCheck = B.CreateCall(
5927a6dacacSDimitry Andric       GuardFnType, GuardCheckLoad,
5937a6dacacSDimitry Andric       {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
5947a6dacacSDimitry Andric 
5957a6dacacSDimitry Andric   // Ensure that the first argument is passed in the correct register.
5967a6dacacSDimitry Andric   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
5977a6dacacSDimitry Andric 
5987a6dacacSDimitry Andric   Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
5997a6dacacSDimitry Andric   SmallVector<Value *> Args;
6007a6dacacSDimitry Andric   for (Argument &Arg : GuestExit->args())
6017a6dacacSDimitry Andric     Args.push_back(&Arg);
6027a6dacacSDimitry Andric   CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
6037a6dacacSDimitry Andric   Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
6047a6dacacSDimitry Andric 
6057a6dacacSDimitry Andric   if (Call->getType()->isVoidTy())
6067a6dacacSDimitry Andric     B.CreateRetVoid();
6077a6dacacSDimitry Andric   else
6087a6dacacSDimitry Andric     B.CreateRet(Call);
6097a6dacacSDimitry Andric 
6107a6dacacSDimitry Andric   auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
6117a6dacacSDimitry Andric   auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
6127a6dacacSDimitry Andric   if (SRetAttr.isValid() && !InRegAttr.isValid()) {
6137a6dacacSDimitry Andric     GuestExit->addParamAttr(0, SRetAttr);
6147a6dacacSDimitry Andric     Call->addParamAttr(0, SRetAttr);
6157a6dacacSDimitry Andric   }
6167a6dacacSDimitry Andric 
6177a6dacacSDimitry Andric   return GuestExit;
6187a6dacacSDimitry Andric }
6197a6dacacSDimitry Andric 
6207a6dacacSDimitry Andric // Lower an indirect call with inline code.
6217a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
6227a6dacacSDimitry Andric   assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
6237a6dacacSDimitry Andric          "Only applicable for Windows targets");
6247a6dacacSDimitry Andric 
6257a6dacacSDimitry Andric   IRBuilder<> B(CB);
6267a6dacacSDimitry Andric   Value *CalledOperand = CB->getCalledOperand();
6277a6dacacSDimitry Andric 
6287a6dacacSDimitry Andric   // If the indirect call is called within catchpad or cleanuppad,
6297a6dacacSDimitry Andric   // we need to copy "funclet" bundle of the call.
6307a6dacacSDimitry Andric   SmallVector<llvm::OperandBundleDef, 1> Bundles;
6317a6dacacSDimitry Andric   if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
6327a6dacacSDimitry Andric     Bundles.push_back(OperandBundleDef(*Bundle));
6337a6dacacSDimitry Andric 
6347a6dacacSDimitry Andric   // Load the global symbol as a pointer to the check function.
6357a6dacacSDimitry Andric   Value *GuardFn;
6367a6dacacSDimitry Andric   if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
6377a6dacacSDimitry Andric     GuardFn = GuardFnCFGlobal;
6387a6dacacSDimitry Andric   else
6397a6dacacSDimitry Andric     GuardFn = GuardFnGlobal;
6407a6dacacSDimitry Andric   LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
6417a6dacacSDimitry Andric 
6427a6dacacSDimitry Andric   // Create new call instruction. The CFGuard check should always be a call,
6437a6dacacSDimitry Andric   // even if the original CallBase is an Invoke or CallBr instruction.
6447a6dacacSDimitry Andric   Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
6457a6dacacSDimitry Andric   CallInst *GuardCheck =
6467a6dacacSDimitry Andric       B.CreateCall(GuardFnType, GuardCheckLoad,
6477a6dacacSDimitry Andric                    {B.CreateBitCast(CalledOperand, B.getPtrTy()),
6487a6dacacSDimitry Andric                     B.CreateBitCast(Thunk, B.getPtrTy())},
6497a6dacacSDimitry Andric                    Bundles);
6507a6dacacSDimitry Andric 
6517a6dacacSDimitry Andric   // Ensure that the first argument is passed in the correct register.
6527a6dacacSDimitry Andric   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
6537a6dacacSDimitry Andric 
6547a6dacacSDimitry Andric   Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
6557a6dacacSDimitry Andric   CB->setCalledOperand(GuardRetVal);
6567a6dacacSDimitry Andric }
6577a6dacacSDimitry Andric 
6587a6dacacSDimitry Andric bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
6597a6dacacSDimitry Andric   if (!GenerateThunks)
6607a6dacacSDimitry Andric     return false;
6617a6dacacSDimitry Andric 
6627a6dacacSDimitry Andric   M = &Mod;
6637a6dacacSDimitry Andric 
6647a6dacacSDimitry Andric   // Check if this module has the cfguard flag and read its value.
6657a6dacacSDimitry Andric   if (auto *MD =
6667a6dacacSDimitry Andric           mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
6677a6dacacSDimitry Andric     cfguard_module_flag = MD->getZExtValue();
6687a6dacacSDimitry Andric 
6697a6dacacSDimitry Andric   PtrTy = PointerType::getUnqual(M->getContext());
6707a6dacacSDimitry Andric   I64Ty = Type::getInt64Ty(M->getContext());
6717a6dacacSDimitry Andric   VoidTy = Type::getVoidTy(M->getContext());
6727a6dacacSDimitry Andric 
6737a6dacacSDimitry Andric   GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
6747a6dacacSDimitry Andric   GuardFnPtrType = PointerType::get(GuardFnType, 0);
6757a6dacacSDimitry Andric   GuardFnCFGlobal =
6767a6dacacSDimitry Andric       M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
6777a6dacacSDimitry Andric   GuardFnGlobal =
6787a6dacacSDimitry Andric       M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
6797a6dacacSDimitry Andric 
6807a6dacacSDimitry Andric   SetVector<Function *> DirectCalledFns;
6817a6dacacSDimitry Andric   for (Function &F : Mod)
6827a6dacacSDimitry Andric     if (!F.isDeclaration() &&
6837a6dacacSDimitry Andric         F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
6847a6dacacSDimitry Andric         F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
6857a6dacacSDimitry Andric       processFunction(F, DirectCalledFns);
6867a6dacacSDimitry Andric 
6877a6dacacSDimitry Andric   struct ThunkInfo {
6887a6dacacSDimitry Andric     Constant *Src;
6897a6dacacSDimitry Andric     Constant *Dst;
6907a6dacacSDimitry Andric     unsigned Kind;
6917a6dacacSDimitry Andric   };
6927a6dacacSDimitry Andric   SmallVector<ThunkInfo> ThunkMapping;
6937a6dacacSDimitry Andric   for (Function &F : Mod) {
6947a6dacacSDimitry Andric     if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
6957a6dacacSDimitry Andric         F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
6967a6dacacSDimitry Andric         F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
6977a6dacacSDimitry Andric       if (!F.hasComdat())
6987a6dacacSDimitry Andric         F.setComdat(Mod.getOrInsertComdat(F.getName()));
6997a6dacacSDimitry Andric       ThunkMapping.push_back({&F, buildEntryThunk(&F), 1});
7007a6dacacSDimitry Andric     }
7017a6dacacSDimitry Andric   }
7027a6dacacSDimitry Andric   for (Function *F : DirectCalledFns) {
7037a6dacacSDimitry Andric     ThunkMapping.push_back(
7047a6dacacSDimitry Andric         {F, buildExitThunk(F->getFunctionType(), F->getAttributes()), 4});
7057a6dacacSDimitry Andric     if (!F->hasDLLImportStorageClass())
7067a6dacacSDimitry Andric       ThunkMapping.push_back({buildGuestExitThunk(F), F, 0});
7077a6dacacSDimitry Andric   }
7087a6dacacSDimitry Andric 
7097a6dacacSDimitry Andric   if (!ThunkMapping.empty()) {
7107a6dacacSDimitry Andric     SmallVector<Constant *> ThunkMappingArrayElems;
7117a6dacacSDimitry Andric     for (ThunkInfo &Thunk : ThunkMapping) {
7127a6dacacSDimitry Andric       ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
7137a6dacacSDimitry Andric           {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
7147a6dacacSDimitry Andric            ConstantExpr::getBitCast(Thunk.Dst, PtrTy),
7157a6dacacSDimitry Andric            ConstantInt::get(M->getContext(), APInt(32, Thunk.Kind))}));
7167a6dacacSDimitry Andric     }
7177a6dacacSDimitry Andric     Constant *ThunkMappingArray = ConstantArray::get(
7187a6dacacSDimitry Andric         llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
7197a6dacacSDimitry Andric                              ThunkMappingArrayElems.size()),
7207a6dacacSDimitry Andric         ThunkMappingArrayElems);
7217a6dacacSDimitry Andric     new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
7227a6dacacSDimitry Andric                        GlobalValue::ExternalLinkage, ThunkMappingArray,
7237a6dacacSDimitry Andric                        "llvm.arm64ec.symbolmap");
7247a6dacacSDimitry Andric   }
7257a6dacacSDimitry Andric 
7267a6dacacSDimitry Andric   return true;
7277a6dacacSDimitry Andric }
7287a6dacacSDimitry Andric 
7297a6dacacSDimitry Andric bool AArch64Arm64ECCallLowering::processFunction(
7307a6dacacSDimitry Andric     Function &F, SetVector<Function *> &DirectCalledFns) {
7317a6dacacSDimitry Andric   SmallVector<CallBase *, 8> IndirectCalls;
7327a6dacacSDimitry Andric 
7337a6dacacSDimitry Andric   // For ARM64EC targets, a function definition's name is mangled differently
7347a6dacacSDimitry Andric   // from the normal symbol. We currently have no representation of this sort
7357a6dacacSDimitry Andric   // of symbol in IR, so we change the name to the mangled name, then store
7367a6dacacSDimitry Andric   // the unmangled name as metadata.  Later passes that need the unmangled
7377a6dacacSDimitry Andric   // name (emitting the definition) can grab it from the metadata.
7387a6dacacSDimitry Andric   //
7397a6dacacSDimitry Andric   // FIXME: Handle functions with weak linkage?
7407a6dacacSDimitry Andric   if (F.hasExternalLinkage() || F.hasWeakLinkage() || F.hasLinkOnceLinkage()) {
7417a6dacacSDimitry Andric     if (std::optional<std::string> MangledName =
7427a6dacacSDimitry Andric             getArm64ECMangledFunctionName(F.getName().str())) {
7437a6dacacSDimitry Andric       F.setMetadata("arm64ec_unmangled_name",
7447a6dacacSDimitry Andric                     MDNode::get(M->getContext(),
7457a6dacacSDimitry Andric                                 MDString::get(M->getContext(), F.getName())));
7467a6dacacSDimitry Andric       if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
7477a6dacacSDimitry Andric         Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
7487a6dacacSDimitry Andric         SmallVector<GlobalObject *> ComdatUsers =
7497a6dacacSDimitry Andric             to_vector(F.getComdat()->getUsers());
7507a6dacacSDimitry Andric         for (GlobalObject *User : ComdatUsers)
7517a6dacacSDimitry Andric           User->setComdat(MangledComdat);
7527a6dacacSDimitry Andric       }
7537a6dacacSDimitry Andric       F.setName(MangledName.value());
7547a6dacacSDimitry Andric     }
7557a6dacacSDimitry Andric   }
7567a6dacacSDimitry Andric 
7577a6dacacSDimitry Andric   // Iterate over the instructions to find all indirect call/invoke/callbr
7587a6dacacSDimitry Andric   // instructions. Make a separate list of pointers to indirect
7597a6dacacSDimitry Andric   // call/invoke/callbr instructions because the original instructions will be
7607a6dacacSDimitry Andric   // deleted as the checks are added.
7617a6dacacSDimitry Andric   for (BasicBlock &BB : F) {
7627a6dacacSDimitry Andric     for (Instruction &I : BB) {
7637a6dacacSDimitry Andric       auto *CB = dyn_cast<CallBase>(&I);
7647a6dacacSDimitry Andric       if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
7657a6dacacSDimitry Andric           CB->isInlineAsm())
7667a6dacacSDimitry Andric         continue;
7677a6dacacSDimitry Andric 
7687a6dacacSDimitry Andric       // We need to instrument any call that isn't directly calling an
7697a6dacacSDimitry Andric       // ARM64 function.
7707a6dacacSDimitry Andric       //
7717a6dacacSDimitry Andric       // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
7727a6dacacSDimitry Andric       // unprototyped functions in C)
7737a6dacacSDimitry Andric       if (Function *F = CB->getCalledFunction()) {
7747a6dacacSDimitry Andric         if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
7757a6dacacSDimitry Andric             F->isIntrinsic() || !F->isDeclaration())
7767a6dacacSDimitry Andric           continue;
7777a6dacacSDimitry Andric 
7787a6dacacSDimitry Andric         DirectCalledFns.insert(F);
7797a6dacacSDimitry Andric         continue;
7807a6dacacSDimitry Andric       }
7817a6dacacSDimitry Andric 
7827a6dacacSDimitry Andric       IndirectCalls.push_back(CB);
7837a6dacacSDimitry Andric       ++Arm64ECCallsLowered;
7847a6dacacSDimitry Andric     }
7857a6dacacSDimitry Andric   }
7867a6dacacSDimitry Andric 
7877a6dacacSDimitry Andric   if (IndirectCalls.empty())
7887a6dacacSDimitry Andric     return false;
7897a6dacacSDimitry Andric 
7907a6dacacSDimitry Andric   for (CallBase *CB : IndirectCalls)
7917a6dacacSDimitry Andric     lowerCall(CB);
7927a6dacacSDimitry Andric 
7937a6dacacSDimitry Andric   return true;
7947a6dacacSDimitry Andric }
7957a6dacacSDimitry Andric 
7967a6dacacSDimitry Andric char AArch64Arm64ECCallLowering::ID = 0;
7977a6dacacSDimitry Andric INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
7987a6dacacSDimitry Andric                 "AArch64Arm64ECCallLowering", false, false)
7997a6dacacSDimitry Andric 
8007a6dacacSDimitry Andric ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
8017a6dacacSDimitry Andric   return new AArch64Arm64ECCallLowering;
8027a6dacacSDimitry Andric }
803