//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file contains the IR transform to lower external or indirect calls for
/// the ARM64EC calling convention. Such calls must go through the runtime, so
/// we can translate the calling convention for calls into the emulator.
///
/// This subsumes Control Flow Guard handling.
///
//===----------------------------------------------------------------------===//
177a6dacacSDimitry Andric
187a6dacacSDimitry Andric #include "AArch64.h"
197a6dacacSDimitry Andric #include "llvm/ADT/SetVector.h"
207a6dacacSDimitry Andric #include "llvm/ADT/SmallString.h"
217a6dacacSDimitry Andric #include "llvm/ADT/SmallVector.h"
227a6dacacSDimitry Andric #include "llvm/ADT/Statistic.h"
237a6dacacSDimitry Andric #include "llvm/IR/CallingConv.h"
240fca6ea1SDimitry Andric #include "llvm/IR/GlobalAlias.h"
257a6dacacSDimitry Andric #include "llvm/IR/IRBuilder.h"
267a6dacacSDimitry Andric #include "llvm/IR/Instruction.h"
270fca6ea1SDimitry Andric #include "llvm/IR/Mangler.h"
280fca6ea1SDimitry Andric #include "llvm/IR/Module.h"
297a6dacacSDimitry Andric #include "llvm/InitializePasses.h"
30439352acSDimitry Andric #include "llvm/Object/COFF.h"
317a6dacacSDimitry Andric #include "llvm/Pass.h"
327a6dacacSDimitry Andric #include "llvm/Support/CommandLine.h"
337a6dacacSDimitry Andric #include "llvm/TargetParser/Triple.h"
347a6dacacSDimitry Andric
357a6dacacSDimitry Andric using namespace llvm;
360fca6ea1SDimitry Andric using namespace llvm::COFF;
377a6dacacSDimitry Andric
387a6dacacSDimitry Andric using OperandBundleDef = OperandBundleDefT<Value *>;
397a6dacacSDimitry Andric
407a6dacacSDimitry Andric #define DEBUG_TYPE "arm64eccalllowering"
417a6dacacSDimitry Andric
427a6dacacSDimitry Andric STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
437a6dacacSDimitry Andric
447a6dacacSDimitry Andric static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
457a6dacacSDimitry Andric cl::Hidden, cl::init(true));
467a6dacacSDimitry Andric static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
477a6dacacSDimitry Andric cl::init(true));
487a6dacacSDimitry Andric
497a6dacacSDimitry Andric namespace {
507a6dacacSDimitry Andric
/// Describes how one thunk argument is converted between its Arm64
/// representation and its x64 representation.
enum ThunkArgTranslation : uint8_t {
  /// Both sides use the same IR type; the value is forwarded unchanged.
  Direct,
  /// The x64 side uses an integer of the same byte size; the value is
  /// reinterpreted by storing it to a stack slot and reloading it.
  Bitcast,
  /// The x64 side receives a pointer; the Arm64 value lives in a stack slot
  /// and the slot's address is what gets passed.
  PointerIndirection,
};
560fca6ea1SDimitry Andric
570fca6ea1SDimitry Andric struct ThunkArgInfo {
580fca6ea1SDimitry Andric Type *Arm64Ty;
590fca6ea1SDimitry Andric Type *X64Ty;
600fca6ea1SDimitry Andric ThunkArgTranslation Translation;
610fca6ea1SDimitry Andric };
62439352acSDimitry Andric
637a6dacacSDimitry Andric class AArch64Arm64ECCallLowering : public ModulePass {
647a6dacacSDimitry Andric public:
657a6dacacSDimitry Andric static char ID;
AArch64Arm64ECCallLowering()667a6dacacSDimitry Andric AArch64Arm64ECCallLowering() : ModulePass(ID) {
677a6dacacSDimitry Andric initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
687a6dacacSDimitry Andric }
697a6dacacSDimitry Andric
707a6dacacSDimitry Andric Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
717a6dacacSDimitry Andric Function *buildEntryThunk(Function *F);
727a6dacacSDimitry Andric void lowerCall(CallBase *CB);
737a6dacacSDimitry Andric Function *buildGuestExitThunk(Function *F);
740fca6ea1SDimitry Andric Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,
750fca6ea1SDimitry Andric GlobalAlias *MangledAlias);
760fca6ea1SDimitry Andric bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,
770fca6ea1SDimitry Andric DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
787a6dacacSDimitry Andric bool runOnModule(Module &M) override;
797a6dacacSDimitry Andric
807a6dacacSDimitry Andric private:
817a6dacacSDimitry Andric int cfguard_module_flag = 0;
827a6dacacSDimitry Andric FunctionType *GuardFnType = nullptr;
837a6dacacSDimitry Andric PointerType *GuardFnPtrType = nullptr;
840fca6ea1SDimitry Andric FunctionType *DispatchFnType = nullptr;
850fca6ea1SDimitry Andric PointerType *DispatchFnPtrType = nullptr;
867a6dacacSDimitry Andric Constant *GuardFnCFGlobal = nullptr;
877a6dacacSDimitry Andric Constant *GuardFnGlobal = nullptr;
880fca6ea1SDimitry Andric Constant *DispatchFnGlobal = nullptr;
897a6dacacSDimitry Andric Module *M = nullptr;
907a6dacacSDimitry Andric
917a6dacacSDimitry Andric Type *PtrTy;
927a6dacacSDimitry Andric Type *I64Ty;
937a6dacacSDimitry Andric Type *VoidTy;
947a6dacacSDimitry Andric
950fca6ea1SDimitry Andric void getThunkType(FunctionType *FT, AttributeList AttrList,
960fca6ea1SDimitry Andric Arm64ECThunkType TT, raw_ostream &Out,
970fca6ea1SDimitry Andric FunctionType *&Arm64Ty, FunctionType *&X64Ty,
980fca6ea1SDimitry Andric SmallVector<ThunkArgTranslation> &ArgTranslations);
997a6dacacSDimitry Andric void getThunkRetType(FunctionType *FT, AttributeList AttrList,
1007a6dacacSDimitry Andric raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
1017a6dacacSDimitry Andric SmallVectorImpl<Type *> &Arm64ArgTypes,
1020fca6ea1SDimitry Andric SmallVectorImpl<Type *> &X64ArgTypes,
1030fca6ea1SDimitry Andric SmallVector<ThunkArgTranslation> &ArgTranslations,
1040fca6ea1SDimitry Andric bool &HasSretPtr);
1050fca6ea1SDimitry Andric void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
1060fca6ea1SDimitry Andric Arm64ECThunkType TT, raw_ostream &Out,
1077a6dacacSDimitry Andric SmallVectorImpl<Type *> &Arm64ArgTypes,
1080fca6ea1SDimitry Andric SmallVectorImpl<Type *> &X64ArgTypes,
1090fca6ea1SDimitry Andric SmallVectorImpl<ThunkArgTranslation> &ArgTranslations,
1100fca6ea1SDimitry Andric bool HasSretPtr);
1110fca6ea1SDimitry Andric ThunkArgInfo canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
1120fca6ea1SDimitry Andric uint64_t ArgSizeBytes, raw_ostream &Out);
1137a6dacacSDimitry Andric };
1147a6dacacSDimitry Andric
1157a6dacacSDimitry Andric } // end anonymous namespace
1167a6dacacSDimitry Andric
getThunkType(FunctionType * FT,AttributeList AttrList,Arm64ECThunkType TT,raw_ostream & Out,FunctionType * & Arm64Ty,FunctionType * & X64Ty,SmallVector<ThunkArgTranslation> & ArgTranslations)1170fca6ea1SDimitry Andric void AArch64Arm64ECCallLowering::getThunkType(
1180fca6ea1SDimitry Andric FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
1190fca6ea1SDimitry Andric raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty,
1200fca6ea1SDimitry Andric SmallVector<ThunkArgTranslation> &ArgTranslations) {
1210fca6ea1SDimitry Andric Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
122439352acSDimitry Andric : "$iexit_thunk$cdecl$");
1237a6dacacSDimitry Andric
1247a6dacacSDimitry Andric Type *Arm64RetTy;
1257a6dacacSDimitry Andric Type *X64RetTy;
1267a6dacacSDimitry Andric
1277a6dacacSDimitry Andric SmallVector<Type *> Arm64ArgTypes;
1287a6dacacSDimitry Andric SmallVector<Type *> X64ArgTypes;
1297a6dacacSDimitry Andric
1307a6dacacSDimitry Andric // The first argument to a thunk is the called function, stored in x9.
1317a6dacacSDimitry Andric // For exit thunks, we pass the called function down to the emulator;
132439352acSDimitry Andric // for entry/guest exit thunks, we just call the Arm64 function directly.
1330fca6ea1SDimitry Andric if (TT == Arm64ECThunkType::Exit)
1347a6dacacSDimitry Andric Arm64ArgTypes.push_back(PtrTy);
1357a6dacacSDimitry Andric X64ArgTypes.push_back(PtrTy);
1367a6dacacSDimitry Andric
1377a6dacacSDimitry Andric bool HasSretPtr = false;
1387a6dacacSDimitry Andric getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
1390fca6ea1SDimitry Andric X64ArgTypes, ArgTranslations, HasSretPtr);
1407a6dacacSDimitry Andric
141439352acSDimitry Andric getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
1420fca6ea1SDimitry Andric ArgTranslations, HasSretPtr);
1437a6dacacSDimitry Andric
1447a6dacacSDimitry Andric Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
145439352acSDimitry Andric
1467a6dacacSDimitry Andric X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
1477a6dacacSDimitry Andric }
1487a6dacacSDimitry Andric
getThunkArgTypes(FunctionType * FT,AttributeList AttrList,Arm64ECThunkType TT,raw_ostream & Out,SmallVectorImpl<Type * > & Arm64ArgTypes,SmallVectorImpl<Type * > & X64ArgTypes,SmallVectorImpl<ThunkArgTranslation> & ArgTranslations,bool HasSretPtr)1497a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::getThunkArgTypes(
1500fca6ea1SDimitry Andric FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
1510fca6ea1SDimitry Andric raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
1520fca6ea1SDimitry Andric SmallVectorImpl<Type *> &X64ArgTypes,
1530fca6ea1SDimitry Andric SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) {
1547a6dacacSDimitry Andric
1557a6dacacSDimitry Andric Out << "$";
1567a6dacacSDimitry Andric if (FT->isVarArg()) {
1577a6dacacSDimitry Andric // We treat the variadic function's thunk as a normal function
1587a6dacacSDimitry Andric // with the following type on the ARM side:
1597a6dacacSDimitry Andric // rettype exitthunk(
1607a6dacacSDimitry Andric // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
1617a6dacacSDimitry Andric //
1627a6dacacSDimitry Andric // that can coverage all types of variadic function.
1637a6dacacSDimitry Andric // x9 is similar to normal exit thunk, store the called function.
1647a6dacacSDimitry Andric // x0-x3 is the arguments be stored in registers.
1657a6dacacSDimitry Andric // x4 is the address of the arguments on the stack.
1667a6dacacSDimitry Andric // x5 is the size of the arguments on the stack.
1677a6dacacSDimitry Andric //
1687a6dacacSDimitry Andric // On the x64 side, it's the same except that x5 isn't set.
1697a6dacacSDimitry Andric //
1707a6dacacSDimitry Andric // If both the ARM and X64 sides are sret, there are only three
1717a6dacacSDimitry Andric // arguments in registers.
1727a6dacacSDimitry Andric //
1737a6dacacSDimitry Andric // If the X64 side is sret, but the ARM side isn't, we pass an extra value
1747a6dacacSDimitry Andric // to/from the X64 side, and let SelectionDAG transform it into a memory
1757a6dacacSDimitry Andric // location.
1767a6dacacSDimitry Andric Out << "varargs";
1777a6dacacSDimitry Andric
1787a6dacacSDimitry Andric // x0-x3
1797a6dacacSDimitry Andric for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
1807a6dacacSDimitry Andric Arm64ArgTypes.push_back(I64Ty);
1817a6dacacSDimitry Andric X64ArgTypes.push_back(I64Ty);
1820fca6ea1SDimitry Andric ArgTranslations.push_back(ThunkArgTranslation::Direct);
1837a6dacacSDimitry Andric }
1847a6dacacSDimitry Andric
1857a6dacacSDimitry Andric // x4
1867a6dacacSDimitry Andric Arm64ArgTypes.push_back(PtrTy);
1877a6dacacSDimitry Andric X64ArgTypes.push_back(PtrTy);
1880fca6ea1SDimitry Andric ArgTranslations.push_back(ThunkArgTranslation::Direct);
1897a6dacacSDimitry Andric // x5
1907a6dacacSDimitry Andric Arm64ArgTypes.push_back(I64Ty);
1910fca6ea1SDimitry Andric if (TT != Arm64ECThunkType::Entry) {
192439352acSDimitry Andric // FIXME: x5 isn't actually used by the x64 side; revisit once we
1937a6dacacSDimitry Andric // have proper isel for varargs
1947a6dacacSDimitry Andric X64ArgTypes.push_back(I64Ty);
1950fca6ea1SDimitry Andric ArgTranslations.push_back(ThunkArgTranslation::Direct);
196439352acSDimitry Andric }
1977a6dacacSDimitry Andric return;
1987a6dacacSDimitry Andric }
1997a6dacacSDimitry Andric
2007a6dacacSDimitry Andric unsigned I = 0;
2017a6dacacSDimitry Andric if (HasSretPtr)
2027a6dacacSDimitry Andric I++;
2037a6dacacSDimitry Andric
2047a6dacacSDimitry Andric if (I == FT->getNumParams()) {
2057a6dacacSDimitry Andric Out << "v";
2067a6dacacSDimitry Andric return;
2077a6dacacSDimitry Andric }
2087a6dacacSDimitry Andric
2097a6dacacSDimitry Andric for (unsigned E = FT->getNumParams(); I != E; ++I) {
2107a6dacacSDimitry Andric #if 0
2117a6dacacSDimitry Andric // FIXME: Need more information about argument size; see
2127a6dacacSDimitry Andric // https://reviews.llvm.org/D132926
2137a6dacacSDimitry Andric uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
2143a079333SDimitry Andric Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
2157a6dacacSDimitry Andric #else
2167a6dacacSDimitry Andric uint64_t ArgSizeBytes = 0;
2173a079333SDimitry Andric Align ParamAlign = Align();
2187a6dacacSDimitry Andric #endif
2190fca6ea1SDimitry Andric auto [Arm64Ty, X64Ty, ArgTranslation] =
2207a6dacacSDimitry Andric canonicalizeThunkType(FT->getParamType(I), ParamAlign,
2210fca6ea1SDimitry Andric /*Ret*/ false, ArgSizeBytes, Out);
2227a6dacacSDimitry Andric Arm64ArgTypes.push_back(Arm64Ty);
2237a6dacacSDimitry Andric X64ArgTypes.push_back(X64Ty);
2240fca6ea1SDimitry Andric ArgTranslations.push_back(ArgTranslation);
2257a6dacacSDimitry Andric }
2267a6dacacSDimitry Andric }
2277a6dacacSDimitry Andric
getThunkRetType(FunctionType * FT,AttributeList AttrList,raw_ostream & Out,Type * & Arm64RetTy,Type * & X64RetTy,SmallVectorImpl<Type * > & Arm64ArgTypes,SmallVectorImpl<Type * > & X64ArgTypes,SmallVector<ThunkArgTranslation> & ArgTranslations,bool & HasSretPtr)2287a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::getThunkRetType(
2297a6dacacSDimitry Andric FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
2307a6dacacSDimitry Andric Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
2310fca6ea1SDimitry Andric SmallVectorImpl<Type *> &X64ArgTypes,
2320fca6ea1SDimitry Andric SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) {
2337a6dacacSDimitry Andric Type *T = FT->getReturnType();
2347a6dacacSDimitry Andric #if 0
2357a6dacacSDimitry Andric // FIXME: Need more information about argument size; see
2367a6dacacSDimitry Andric // https://reviews.llvm.org/D132926
2377a6dacacSDimitry Andric uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
2387a6dacacSDimitry Andric #else
2397a6dacacSDimitry Andric int64_t ArgSizeBytes = 0;
2407a6dacacSDimitry Andric #endif
2417a6dacacSDimitry Andric if (T->isVoidTy()) {
2427a6dacacSDimitry Andric if (FT->getNumParams()) {
2430fca6ea1SDimitry Andric Attribute SRetAttr0 = AttrList.getParamAttr(0, Attribute::StructRet);
2440fca6ea1SDimitry Andric Attribute InRegAttr0 = AttrList.getParamAttr(0, Attribute::InReg);
2450fca6ea1SDimitry Andric Attribute SRetAttr1, InRegAttr1;
2460fca6ea1SDimitry Andric if (FT->getNumParams() > 1) {
2470fca6ea1SDimitry Andric // Also check the second parameter (for class methods, the first
2480fca6ea1SDimitry Andric // parameter is "this", and the second parameter is the sret pointer.)
2490fca6ea1SDimitry Andric // It doesn't matter which one is sret.
2500fca6ea1SDimitry Andric SRetAttr1 = AttrList.getParamAttr(1, Attribute::StructRet);
2510fca6ea1SDimitry Andric InRegAttr1 = AttrList.getParamAttr(1, Attribute::InReg);
2520fca6ea1SDimitry Andric }
2530fca6ea1SDimitry Andric if ((SRetAttr0.isValid() && InRegAttr0.isValid()) ||
2540fca6ea1SDimitry Andric (SRetAttr1.isValid() && InRegAttr1.isValid())) {
2557a6dacacSDimitry Andric // sret+inreg indicates a call that returns a C++ class value. This is
2567a6dacacSDimitry Andric // actually equivalent to just passing and returning a void* pointer
2570fca6ea1SDimitry Andric // as the first or second argument. Translate it that way, instead of
2580fca6ea1SDimitry Andric // trying to model "inreg" in the thunk's calling convention; this
2590fca6ea1SDimitry Andric // simplfies the rest of the code, and matches MSVC mangling.
2607a6dacacSDimitry Andric Out << "i8";
2617a6dacacSDimitry Andric Arm64RetTy = I64Ty;
2627a6dacacSDimitry Andric X64RetTy = I64Ty;
2637a6dacacSDimitry Andric return;
2647a6dacacSDimitry Andric }
2650fca6ea1SDimitry Andric if (SRetAttr0.isValid()) {
2667a6dacacSDimitry Andric // FIXME: Sanity-check the sret type; if it's an integer or pointer,
2677a6dacacSDimitry Andric // we'll get screwy mangling/codegen.
2687a6dacacSDimitry Andric // FIXME: For large struct types, mangle as an integer argument and
2697a6dacacSDimitry Andric // integer return, so we can reuse more thunks, instead of "m" syntax.
2707a6dacacSDimitry Andric // (MSVC mangles this case as an integer return with no argument, but
2717a6dacacSDimitry Andric // that's a miscompile.)
2720fca6ea1SDimitry Andric Type *SRetType = SRetAttr0.getValueAsType();
2737a6dacacSDimitry Andric Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
2747a6dacacSDimitry Andric canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
2750fca6ea1SDimitry Andric Out);
2767a6dacacSDimitry Andric Arm64RetTy = VoidTy;
2777a6dacacSDimitry Andric X64RetTy = VoidTy;
2787a6dacacSDimitry Andric Arm64ArgTypes.push_back(FT->getParamType(0));
2797a6dacacSDimitry Andric X64ArgTypes.push_back(FT->getParamType(0));
2800fca6ea1SDimitry Andric ArgTranslations.push_back(ThunkArgTranslation::Direct);
2817a6dacacSDimitry Andric HasSretPtr = true;
2827a6dacacSDimitry Andric return;
2837a6dacacSDimitry Andric }
2847a6dacacSDimitry Andric }
2857a6dacacSDimitry Andric
2867a6dacacSDimitry Andric Out << "v";
2877a6dacacSDimitry Andric Arm64RetTy = VoidTy;
2887a6dacacSDimitry Andric X64RetTy = VoidTy;
2897a6dacacSDimitry Andric return;
2907a6dacacSDimitry Andric }
2917a6dacacSDimitry Andric
2920fca6ea1SDimitry Andric auto info =
2930fca6ea1SDimitry Andric canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out);
2940fca6ea1SDimitry Andric Arm64RetTy = info.Arm64Ty;
2950fca6ea1SDimitry Andric X64RetTy = info.X64Ty;
2967a6dacacSDimitry Andric if (X64RetTy->isPointerTy()) {
2977a6dacacSDimitry Andric // If the X64 type is canonicalized to a pointer, that means it's
2987a6dacacSDimitry Andric // passed/returned indirectly. For a return value, that means it's an
2997a6dacacSDimitry Andric // sret pointer.
3007a6dacacSDimitry Andric X64ArgTypes.push_back(X64RetTy);
3017a6dacacSDimitry Andric X64RetTy = VoidTy;
3027a6dacacSDimitry Andric }
3037a6dacacSDimitry Andric }
3047a6dacacSDimitry Andric
canonicalizeThunkType(Type * T,Align Alignment,bool Ret,uint64_t ArgSizeBytes,raw_ostream & Out)3050fca6ea1SDimitry Andric ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
3060fca6ea1SDimitry Andric Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,
3070fca6ea1SDimitry Andric raw_ostream &Out) {
3080fca6ea1SDimitry Andric
3090fca6ea1SDimitry Andric auto direct = [](Type *T) {
3100fca6ea1SDimitry Andric return ThunkArgInfo{T, T, ThunkArgTranslation::Direct};
3110fca6ea1SDimitry Andric };
3120fca6ea1SDimitry Andric
3130fca6ea1SDimitry Andric auto bitcast = [this](Type *Arm64Ty, uint64_t SizeInBytes) {
3140fca6ea1SDimitry Andric return ThunkArgInfo{Arm64Ty,
3150fca6ea1SDimitry Andric llvm::Type::getIntNTy(M->getContext(), SizeInBytes * 8),
3160fca6ea1SDimitry Andric ThunkArgTranslation::Bitcast};
3170fca6ea1SDimitry Andric };
3180fca6ea1SDimitry Andric
3190fca6ea1SDimitry Andric auto pointerIndirection = [this](Type *Arm64Ty) {
3200fca6ea1SDimitry Andric return ThunkArgInfo{Arm64Ty, PtrTy,
3210fca6ea1SDimitry Andric ThunkArgTranslation::PointerIndirection};
3220fca6ea1SDimitry Andric };
3230fca6ea1SDimitry Andric
3247a6dacacSDimitry Andric if (T->isFloatTy()) {
3257a6dacacSDimitry Andric Out << "f";
3260fca6ea1SDimitry Andric return direct(T);
3277a6dacacSDimitry Andric }
3287a6dacacSDimitry Andric
3297a6dacacSDimitry Andric if (T->isDoubleTy()) {
3307a6dacacSDimitry Andric Out << "d";
3310fca6ea1SDimitry Andric return direct(T);
3327a6dacacSDimitry Andric }
3337a6dacacSDimitry Andric
3347a6dacacSDimitry Andric if (T->isFloatingPointTy()) {
3357a6dacacSDimitry Andric report_fatal_error(
3367a6dacacSDimitry Andric "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
3377a6dacacSDimitry Andric }
3387a6dacacSDimitry Andric
3397a6dacacSDimitry Andric auto &DL = M->getDataLayout();
3407a6dacacSDimitry Andric
3417a6dacacSDimitry Andric if (auto *StructTy = dyn_cast<StructType>(T))
3427a6dacacSDimitry Andric if (StructTy->getNumElements() == 1)
3437a6dacacSDimitry Andric T = StructTy->getElementType(0);
3447a6dacacSDimitry Andric
3457a6dacacSDimitry Andric if (T->isArrayTy()) {
3467a6dacacSDimitry Andric Type *ElementTy = T->getArrayElementType();
3477a6dacacSDimitry Andric uint64_t ElementCnt = T->getArrayNumElements();
3487a6dacacSDimitry Andric uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
3497a6dacacSDimitry Andric uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
3507a6dacacSDimitry Andric if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
3517a6dacacSDimitry Andric Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
3523a079333SDimitry Andric if (Alignment.value() >= 16 && !Ret)
3537a6dacacSDimitry Andric Out << "a" << Alignment.value();
3547a6dacacSDimitry Andric if (TotalSizeBytes <= 8) {
3557a6dacacSDimitry Andric // Arm64 returns small structs of float/double in float registers;
3567a6dacacSDimitry Andric // X64 uses RAX.
3570fca6ea1SDimitry Andric return bitcast(T, TotalSizeBytes);
3587a6dacacSDimitry Andric } else {
3597a6dacacSDimitry Andric // Struct is passed directly on Arm64, but indirectly on X64.
3600fca6ea1SDimitry Andric return pointerIndirection(T);
3617a6dacacSDimitry Andric }
3627a6dacacSDimitry Andric } else if (T->isFloatingPointTy()) {
3637a6dacacSDimitry Andric report_fatal_error("Only 32 and 64 bit floating points are supported for "
3647a6dacacSDimitry Andric "ARM64EC thunks");
3657a6dacacSDimitry Andric }
3667a6dacacSDimitry Andric }
3677a6dacacSDimitry Andric
3687a6dacacSDimitry Andric if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
3697a6dacacSDimitry Andric Out << "i8";
3700fca6ea1SDimitry Andric return direct(I64Ty);
3717a6dacacSDimitry Andric }
3727a6dacacSDimitry Andric
3737a6dacacSDimitry Andric unsigned TypeSize = ArgSizeBytes;
3747a6dacacSDimitry Andric if (TypeSize == 0)
3757a6dacacSDimitry Andric TypeSize = DL.getTypeSizeInBits(T) / 8;
3767a6dacacSDimitry Andric Out << "m";
3777a6dacacSDimitry Andric if (TypeSize != 4)
3787a6dacacSDimitry Andric Out << TypeSize;
3793a079333SDimitry Andric if (Alignment.value() >= 16 && !Ret)
3807a6dacacSDimitry Andric Out << "a" << Alignment.value();
3817a6dacacSDimitry Andric // FIXME: Try to canonicalize Arm64Ty more thoroughly?
3827a6dacacSDimitry Andric if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
3837a6dacacSDimitry Andric // Pass directly in an integer register
3840fca6ea1SDimitry Andric return bitcast(T, TypeSize);
3857a6dacacSDimitry Andric } else {
3867a6dacacSDimitry Andric // Passed directly on Arm64, but indirectly on X64.
3870fca6ea1SDimitry Andric return pointerIndirection(T);
3887a6dacacSDimitry Andric }
3897a6dacacSDimitry Andric }
3907a6dacacSDimitry Andric
3917a6dacacSDimitry Andric // This function builds the "exit thunk", a function which translates
3927a6dacacSDimitry Andric // arguments and return values when calling x64 code from AArch64 code.
buildExitThunk(FunctionType * FT,AttributeList Attrs)3937a6dacacSDimitry Andric Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
3947a6dacacSDimitry Andric AttributeList Attrs) {
3957a6dacacSDimitry Andric SmallString<256> ExitThunkName;
3967a6dacacSDimitry Andric llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
3977a6dacacSDimitry Andric FunctionType *Arm64Ty, *X64Ty;
3980fca6ea1SDimitry Andric SmallVector<ThunkArgTranslation> ArgTranslations;
3990fca6ea1SDimitry Andric getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
4000fca6ea1SDimitry Andric X64Ty, ArgTranslations);
4017a6dacacSDimitry Andric if (Function *F = M->getFunction(ExitThunkName))
4027a6dacacSDimitry Andric return F;
4037a6dacacSDimitry Andric
4047a6dacacSDimitry Andric Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
4057a6dacacSDimitry Andric ExitThunkName, M);
4067a6dacacSDimitry Andric F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
4077a6dacacSDimitry Andric F->setSection(".wowthk$aa");
4087a6dacacSDimitry Andric F->setComdat(M->getOrInsertComdat(ExitThunkName));
4097a6dacacSDimitry Andric // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
4107a6dacacSDimitry Andric F->addFnAttr("frame-pointer", "all");
4117a6dacacSDimitry Andric // Only copy sret from the first argument. For C++ instance methods, clang can
4127a6dacacSDimitry Andric // stick an sret marking on a later argument, but it doesn't actually affect
4137a6dacacSDimitry Andric // the ABI, so we can omit it. This avoids triggering a verifier assertion.
4147a6dacacSDimitry Andric if (FT->getNumParams()) {
4157a6dacacSDimitry Andric auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
4167a6dacacSDimitry Andric auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
4177a6dacacSDimitry Andric if (SRet.isValid() && !InReg.isValid())
4187a6dacacSDimitry Andric F->addParamAttr(1, SRet);
4197a6dacacSDimitry Andric }
4207a6dacacSDimitry Andric // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
4217a6dacacSDimitry Andric // C ABI, but might show up in other cases.
4227a6dacacSDimitry Andric BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
4237a6dacacSDimitry Andric IRBuilder<> IRB(BB);
4247a6dacacSDimitry Andric Value *CalleePtr =
4257a6dacacSDimitry Andric M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
4267a6dacacSDimitry Andric Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
4277a6dacacSDimitry Andric auto &DL = M->getDataLayout();
4287a6dacacSDimitry Andric SmallVector<Value *> Args;
4297a6dacacSDimitry Andric
4307a6dacacSDimitry Andric // Pass the called function in x9.
4310fca6ea1SDimitry Andric auto X64TyOffset = 1;
4327a6dacacSDimitry Andric Args.push_back(F->arg_begin());
4337a6dacacSDimitry Andric
4347a6dacacSDimitry Andric Type *RetTy = Arm64Ty->getReturnType();
4357a6dacacSDimitry Andric if (RetTy != X64Ty->getReturnType()) {
4367a6dacacSDimitry Andric // If the return type is an array or struct, translate it. Values of size
4377a6dacacSDimitry Andric // 8 or less go into RAX; bigger values go into memory, and we pass a
4387a6dacacSDimitry Andric // pointer.
4397a6dacacSDimitry Andric if (DL.getTypeStoreSize(RetTy) > 8) {
4407a6dacacSDimitry Andric Args.push_back(IRB.CreateAlloca(RetTy));
4410fca6ea1SDimitry Andric X64TyOffset++;
4427a6dacacSDimitry Andric }
4437a6dacacSDimitry Andric }
4447a6dacacSDimitry Andric
4450fca6ea1SDimitry Andric for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal(
4460fca6ea1SDimitry Andric make_range(F->arg_begin() + 1, F->arg_end()),
4470fca6ea1SDimitry Andric make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()),
4480fca6ea1SDimitry Andric ArgTranslations)) {
4497a6dacacSDimitry Andric // Translate arguments from AArch64 calling convention to x86 calling
4507a6dacacSDimitry Andric // convention.
4517a6dacacSDimitry Andric //
4527a6dacacSDimitry Andric // For simple types, we don't need to do any translation: they're
4537a6dacacSDimitry Andric // represented the same way. (Implicit sign extension is not part of
4547a6dacacSDimitry Andric // either convention.)
4557a6dacacSDimitry Andric //
4567a6dacacSDimitry Andric // The big thing we have to worry about is struct types... but
4577a6dacacSDimitry Andric // fortunately AArch64 clang is pretty friendly here: the cases that need
4587a6dacacSDimitry Andric // translation are always passed as a struct or array. (If we run into
4597a6dacacSDimitry Andric // some cases where this doesn't work, we can teach clang to mark it up
4607a6dacacSDimitry Andric // with an attribute.)
4617a6dacacSDimitry Andric //
4627a6dacacSDimitry Andric // The first argument is the called function, stored in x9.
4630fca6ea1SDimitry Andric if (ArgTranslation != ThunkArgTranslation::Direct) {
4647a6dacacSDimitry Andric Value *Mem = IRB.CreateAlloca(Arg.getType());
4657a6dacacSDimitry Andric IRB.CreateStore(&Arg, Mem);
4660fca6ea1SDimitry Andric if (ArgTranslation == ThunkArgTranslation::Bitcast) {
4677a6dacacSDimitry Andric Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
4687a6dacacSDimitry Andric Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
4690fca6ea1SDimitry Andric } else {
4700fca6ea1SDimitry Andric assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
4717a6dacacSDimitry Andric Args.push_back(Mem);
4720fca6ea1SDimitry Andric }
4737a6dacacSDimitry Andric } else {
4747a6dacacSDimitry Andric Args.push_back(&Arg);
4757a6dacacSDimitry Andric }
4760fca6ea1SDimitry Andric assert(Args.back()->getType() == X64ArgType);
4777a6dacacSDimitry Andric }
4787a6dacacSDimitry Andric // FIXME: Transfer necessary attributes? sret? anything else?
4797a6dacacSDimitry Andric
4807a6dacacSDimitry Andric Callee = IRB.CreateBitCast(Callee, PtrTy);
4817a6dacacSDimitry Andric CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
4827a6dacacSDimitry Andric Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
4837a6dacacSDimitry Andric
4847a6dacacSDimitry Andric Value *RetVal = Call;
4857a6dacacSDimitry Andric if (RetTy != X64Ty->getReturnType()) {
4867a6dacacSDimitry Andric // If we rewrote the return type earlier, convert the return value to
4877a6dacacSDimitry Andric // the proper type.
4887a6dacacSDimitry Andric if (DL.getTypeStoreSize(RetTy) > 8) {
4897a6dacacSDimitry Andric RetVal = IRB.CreateLoad(RetTy, Args[1]);
4907a6dacacSDimitry Andric } else {
4917a6dacacSDimitry Andric Value *CastAlloca = IRB.CreateAlloca(RetTy);
4927a6dacacSDimitry Andric IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
4937a6dacacSDimitry Andric RetVal = IRB.CreateLoad(RetTy, CastAlloca);
4947a6dacacSDimitry Andric }
4957a6dacacSDimitry Andric }
4967a6dacacSDimitry Andric
4977a6dacacSDimitry Andric if (RetTy->isVoidTy())
4987a6dacacSDimitry Andric IRB.CreateRetVoid();
4997a6dacacSDimitry Andric else
5007a6dacacSDimitry Andric IRB.CreateRet(RetVal);
5017a6dacacSDimitry Andric return F;
5027a6dacacSDimitry Andric }
5037a6dacacSDimitry Andric
5047a6dacacSDimitry Andric // This function builds the "entry thunk", a function which translates
5057a6dacacSDimitry Andric // arguments and return values when calling AArch64 code from x64 code.
buildEntryThunk(Function * F)5067a6dacacSDimitry Andric Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
5077a6dacacSDimitry Andric SmallString<256> EntryThunkName;
5087a6dacacSDimitry Andric llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
5097a6dacacSDimitry Andric FunctionType *Arm64Ty, *X64Ty;
5100fca6ea1SDimitry Andric SmallVector<ThunkArgTranslation> ArgTranslations;
5110fca6ea1SDimitry Andric getThunkType(F->getFunctionType(), F->getAttributes(),
5120fca6ea1SDimitry Andric Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,
5130fca6ea1SDimitry Andric ArgTranslations);
5147a6dacacSDimitry Andric if (Function *F = M->getFunction(EntryThunkName))
5157a6dacacSDimitry Andric return F;
5167a6dacacSDimitry Andric
5177a6dacacSDimitry Andric Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
5187a6dacacSDimitry Andric EntryThunkName, M);
5197a6dacacSDimitry Andric Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
5207a6dacacSDimitry Andric Thunk->setSection(".wowthk$aa");
5217a6dacacSDimitry Andric Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
5227a6dacacSDimitry Andric // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
5237a6dacacSDimitry Andric Thunk->addFnAttr("frame-pointer", "all");
5247a6dacacSDimitry Andric
5257a6dacacSDimitry Andric BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
5267a6dacacSDimitry Andric IRBuilder<> IRB(BB);
5277a6dacacSDimitry Andric
5287a6dacacSDimitry Andric Type *RetTy = Arm64Ty->getReturnType();
5297a6dacacSDimitry Andric Type *X64RetType = X64Ty->getReturnType();
5307a6dacacSDimitry Andric
5317a6dacacSDimitry Andric bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
5327a6dacacSDimitry Andric unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
5330fca6ea1SDimitry Andric unsigned PassthroughArgSize =
5340fca6ea1SDimitry Andric (F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset;
5350fca6ea1SDimitry Andric assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize));
5367a6dacacSDimitry Andric
5377a6dacacSDimitry Andric // Translate arguments to call.
5387a6dacacSDimitry Andric SmallVector<Value *> Args;
5390fca6ea1SDimitry Andric for (unsigned i = 0; i != PassthroughArgSize; ++i) {
5400fca6ea1SDimitry Andric Value *Arg = Thunk->getArg(i + ThunkArgOffset);
5410fca6ea1SDimitry Andric Type *ArgTy = Arm64Ty->getParamType(i);
5420fca6ea1SDimitry Andric ThunkArgTranslation ArgTranslation = ArgTranslations[i];
5430fca6ea1SDimitry Andric if (ArgTranslation != ThunkArgTranslation::Direct) {
5447a6dacacSDimitry Andric // Translate array/struct arguments to the expected type.
5450fca6ea1SDimitry Andric if (ArgTranslation == ThunkArgTranslation::Bitcast) {
5467a6dacacSDimitry Andric Value *CastAlloca = IRB.CreateAlloca(ArgTy);
5477a6dacacSDimitry Andric IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
5487a6dacacSDimitry Andric Arg = IRB.CreateLoad(ArgTy, CastAlloca);
5497a6dacacSDimitry Andric } else {
5500fca6ea1SDimitry Andric assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
5517a6dacacSDimitry Andric Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
5527a6dacacSDimitry Andric }
5537a6dacacSDimitry Andric }
5540fca6ea1SDimitry Andric assert(Arg->getType() == ArgTy);
5557a6dacacSDimitry Andric Args.push_back(Arg);
5567a6dacacSDimitry Andric }
5577a6dacacSDimitry Andric
558439352acSDimitry Andric if (F->isVarArg()) {
559439352acSDimitry Andric // The 5th argument to variadic entry thunks is used to model the x64 sp
560439352acSDimitry Andric // which is passed to the thunk in x4, this can be passed to the callee as
561439352acSDimitry Andric // the variadic argument start address after skipping over the 32 byte
562439352acSDimitry Andric // shadow store.
563439352acSDimitry Andric
564439352acSDimitry Andric // The EC thunk CC will assign any argument marked as InReg to x4.
565439352acSDimitry Andric Thunk->addParamAttr(5, Attribute::InReg);
566439352acSDimitry Andric Value *Arg = Thunk->getArg(5);
567439352acSDimitry Andric Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
568439352acSDimitry Andric Args.push_back(Arg);
569439352acSDimitry Andric
570439352acSDimitry Andric // Pass in a zero variadic argument size (in x5).
571439352acSDimitry Andric Args.push_back(IRB.getInt64(0));
572439352acSDimitry Andric }
573439352acSDimitry Andric
5747a6dacacSDimitry Andric // Call the function passed to the thunk.
5757a6dacacSDimitry Andric Value *Callee = Thunk->getArg(0);
5767a6dacacSDimitry Andric Callee = IRB.CreateBitCast(Callee, PtrTy);
5773a079333SDimitry Andric CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
5783a079333SDimitry Andric
5793a079333SDimitry Andric auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
5803a079333SDimitry Andric auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
5813a079333SDimitry Andric if (SRetAttr.isValid() && !InRegAttr.isValid()) {
5823a079333SDimitry Andric Thunk->addParamAttr(1, SRetAttr);
5833a079333SDimitry Andric Call->addParamAttr(0, SRetAttr);
5843a079333SDimitry Andric }
5857a6dacacSDimitry Andric
5867a6dacacSDimitry Andric Value *RetVal = Call;
5877a6dacacSDimitry Andric if (TransformDirectToSRet) {
5887a6dacacSDimitry Andric IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
5897a6dacacSDimitry Andric } else if (X64RetType != RetTy) {
5907a6dacacSDimitry Andric Value *CastAlloca = IRB.CreateAlloca(X64RetType);
5917a6dacacSDimitry Andric IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
5927a6dacacSDimitry Andric RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
5937a6dacacSDimitry Andric }
5947a6dacacSDimitry Andric
5957a6dacacSDimitry Andric // Return to the caller. Note that the isel has code to translate this
5967a6dacacSDimitry Andric // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
5977a6dacacSDimitry Andric // could emit a tail call here, but that would require a dedicated calling
5987a6dacacSDimitry Andric // convention, which seems more complicated overall.)
5997a6dacacSDimitry Andric if (X64RetType->isVoidTy())
6007a6dacacSDimitry Andric IRB.CreateRetVoid();
6017a6dacacSDimitry Andric else
6027a6dacacSDimitry Andric IRB.CreateRet(RetVal);
6037a6dacacSDimitry Andric
6047a6dacacSDimitry Andric return Thunk;
6057a6dacacSDimitry Andric }
6067a6dacacSDimitry Andric
6077a6dacacSDimitry Andric // Builds the "guest exit thunk", a helper to call a function which may or may
6087a6dacacSDimitry Andric // not be an exit thunk. (We optimistically assume non-dllimport function
6097a6dacacSDimitry Andric // declarations refer to functions defined in AArch64 code; if the linker
6107a6dacacSDimitry Andric // can't prove that, we use this routine instead.)
buildGuestExitThunk(Function * F)6117a6dacacSDimitry Andric Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
6127a6dacacSDimitry Andric llvm::raw_null_ostream NullThunkName;
6137a6dacacSDimitry Andric FunctionType *Arm64Ty, *X64Ty;
6140fca6ea1SDimitry Andric SmallVector<ThunkArgTranslation> ArgTranslations;
6150fca6ea1SDimitry Andric getThunkType(F->getFunctionType(), F->getAttributes(),
6160fca6ea1SDimitry Andric Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
6170fca6ea1SDimitry Andric ArgTranslations);
6187a6dacacSDimitry Andric auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
6197a6dacacSDimitry Andric assert(MangledName && "Can't guest exit to function that's already native");
6207a6dacacSDimitry Andric std::string ThunkName = *MangledName;
6217a6dacacSDimitry Andric if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
6227a6dacacSDimitry Andric ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
6237a6dacacSDimitry Andric } else {
6247a6dacacSDimitry Andric ThunkName.append("$exit_thunk");
6257a6dacacSDimitry Andric }
6267a6dacacSDimitry Andric Function *GuestExit =
6277a6dacacSDimitry Andric Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
6287a6dacacSDimitry Andric GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
6297a6dacacSDimitry Andric GuestExit->setSection(".wowthk$aa");
6307a6dacacSDimitry Andric GuestExit->setMetadata(
6317a6dacacSDimitry Andric "arm64ec_unmangled_name",
6327a6dacacSDimitry Andric MDNode::get(M->getContext(),
6337a6dacacSDimitry Andric MDString::get(M->getContext(), F->getName())));
6347a6dacacSDimitry Andric GuestExit->setMetadata(
6357a6dacacSDimitry Andric "arm64ec_ecmangled_name",
6367a6dacacSDimitry Andric MDNode::get(M->getContext(),
6377a6dacacSDimitry Andric MDString::get(M->getContext(), *MangledName)));
6387a6dacacSDimitry Andric F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
6397a6dacacSDimitry Andric BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
6407a6dacacSDimitry Andric IRBuilder<> B(BB);
6417a6dacacSDimitry Andric
6427a6dacacSDimitry Andric // Load the global symbol as a pointer to the check function.
6437a6dacacSDimitry Andric Value *GuardFn;
6447a6dacacSDimitry Andric if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
6457a6dacacSDimitry Andric GuardFn = GuardFnCFGlobal;
6467a6dacacSDimitry Andric else
6477a6dacacSDimitry Andric GuardFn = GuardFnGlobal;
6487a6dacacSDimitry Andric LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
6497a6dacacSDimitry Andric
6507a6dacacSDimitry Andric // Create new call instruction. The CFGuard check should always be a call,
6517a6dacacSDimitry Andric // even if the original CallBase is an Invoke or CallBr instruction.
6527a6dacacSDimitry Andric Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
6537a6dacacSDimitry Andric CallInst *GuardCheck = B.CreateCall(
6547a6dacacSDimitry Andric GuardFnType, GuardCheckLoad,
6557a6dacacSDimitry Andric {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
6567a6dacacSDimitry Andric
6577a6dacacSDimitry Andric // Ensure that the first argument is passed in the correct register.
6587a6dacacSDimitry Andric GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
6597a6dacacSDimitry Andric
6607a6dacacSDimitry Andric Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
6617a6dacacSDimitry Andric SmallVector<Value *> Args;
6627a6dacacSDimitry Andric for (Argument &Arg : GuestExit->args())
6637a6dacacSDimitry Andric Args.push_back(&Arg);
6647a6dacacSDimitry Andric CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
6657a6dacacSDimitry Andric Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
6667a6dacacSDimitry Andric
6677a6dacacSDimitry Andric if (Call->getType()->isVoidTy())
6687a6dacacSDimitry Andric B.CreateRetVoid();
6697a6dacacSDimitry Andric else
6707a6dacacSDimitry Andric B.CreateRet(Call);
6717a6dacacSDimitry Andric
6727a6dacacSDimitry Andric auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
6737a6dacacSDimitry Andric auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
6747a6dacacSDimitry Andric if (SRetAttr.isValid() && !InRegAttr.isValid()) {
6757a6dacacSDimitry Andric GuestExit->addParamAttr(0, SRetAttr);
6767a6dacacSDimitry Andric Call->addParamAttr(0, SRetAttr);
6777a6dacacSDimitry Andric }
6787a6dacacSDimitry Andric
6797a6dacacSDimitry Andric return GuestExit;
6807a6dacacSDimitry Andric }
6817a6dacacSDimitry Andric
6820fca6ea1SDimitry Andric Function *
buildPatchableThunk(GlobalAlias * UnmangledAlias,GlobalAlias * MangledAlias)6830fca6ea1SDimitry Andric AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,
6840fca6ea1SDimitry Andric GlobalAlias *MangledAlias) {
6850fca6ea1SDimitry Andric llvm::raw_null_ostream NullThunkName;
6860fca6ea1SDimitry Andric FunctionType *Arm64Ty, *X64Ty;
6870fca6ea1SDimitry Andric Function *F = cast<Function>(MangledAlias->getAliasee());
6880fca6ea1SDimitry Andric SmallVector<ThunkArgTranslation> ArgTranslations;
6890fca6ea1SDimitry Andric getThunkType(F->getFunctionType(), F->getAttributes(),
6900fca6ea1SDimitry Andric Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
6910fca6ea1SDimitry Andric ArgTranslations);
6920fca6ea1SDimitry Andric std::string ThunkName(MangledAlias->getName());
6930fca6ea1SDimitry Andric if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
6940fca6ea1SDimitry Andric ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");
6950fca6ea1SDimitry Andric } else {
6960fca6ea1SDimitry Andric ThunkName.append("$hybpatch_thunk");
6970fca6ea1SDimitry Andric }
6980fca6ea1SDimitry Andric
6990fca6ea1SDimitry Andric Function *GuestExit =
7000fca6ea1SDimitry Andric Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
7010fca6ea1SDimitry Andric GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
7020fca6ea1SDimitry Andric GuestExit->setSection(".wowthk$aa");
7030fca6ea1SDimitry Andric BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
7040fca6ea1SDimitry Andric IRBuilder<> B(BB);
7050fca6ea1SDimitry Andric
7060fca6ea1SDimitry Andric // Load the global symbol as a pointer to the check function.
7070fca6ea1SDimitry Andric LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal);
7080fca6ea1SDimitry Andric
7090fca6ea1SDimitry Andric // Create new dispatch call instruction.
7100fca6ea1SDimitry Andric Function *ExitThunk =
7110fca6ea1SDimitry Andric buildExitThunk(F->getFunctionType(), F->getAttributes());
7120fca6ea1SDimitry Andric CallInst *Dispatch =
7130fca6ea1SDimitry Andric B.CreateCall(DispatchFnType, DispatchLoad,
7140fca6ea1SDimitry Andric {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()});
7150fca6ea1SDimitry Andric
7160fca6ea1SDimitry Andric // Ensure that the first arguments are passed in the correct registers.
7170fca6ea1SDimitry Andric Dispatch->setCallingConv(CallingConv::CFGuard_Check);
7180fca6ea1SDimitry Andric
7190fca6ea1SDimitry Andric Value *DispatchRetVal = B.CreateBitCast(Dispatch, PtrTy);
7200fca6ea1SDimitry Andric SmallVector<Value *> Args;
7210fca6ea1SDimitry Andric for (Argument &Arg : GuestExit->args())
7220fca6ea1SDimitry Andric Args.push_back(&Arg);
7230fca6ea1SDimitry Andric CallInst *Call = B.CreateCall(Arm64Ty, DispatchRetVal, Args);
7240fca6ea1SDimitry Andric Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
7250fca6ea1SDimitry Andric
7260fca6ea1SDimitry Andric if (Call->getType()->isVoidTy())
7270fca6ea1SDimitry Andric B.CreateRetVoid();
7280fca6ea1SDimitry Andric else
7290fca6ea1SDimitry Andric B.CreateRet(Call);
7300fca6ea1SDimitry Andric
7310fca6ea1SDimitry Andric auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
7320fca6ea1SDimitry Andric auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
7330fca6ea1SDimitry Andric if (SRetAttr.isValid() && !InRegAttr.isValid()) {
7340fca6ea1SDimitry Andric GuestExit->addParamAttr(0, SRetAttr);
7350fca6ea1SDimitry Andric Call->addParamAttr(0, SRetAttr);
7360fca6ea1SDimitry Andric }
7370fca6ea1SDimitry Andric
7380fca6ea1SDimitry Andric MangledAlias->setAliasee(GuestExit);
7390fca6ea1SDimitry Andric return GuestExit;
7400fca6ea1SDimitry Andric }
7410fca6ea1SDimitry Andric
7427a6dacacSDimitry Andric // Lower an indirect call with inline code.
lowerCall(CallBase * CB)7437a6dacacSDimitry Andric void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
7447a6dacacSDimitry Andric assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
7457a6dacacSDimitry Andric "Only applicable for Windows targets");
7467a6dacacSDimitry Andric
7477a6dacacSDimitry Andric IRBuilder<> B(CB);
7487a6dacacSDimitry Andric Value *CalledOperand = CB->getCalledOperand();
7497a6dacacSDimitry Andric
7507a6dacacSDimitry Andric // If the indirect call is called within catchpad or cleanuppad,
7517a6dacacSDimitry Andric // we need to copy "funclet" bundle of the call.
7527a6dacacSDimitry Andric SmallVector<llvm::OperandBundleDef, 1> Bundles;
7537a6dacacSDimitry Andric if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
7547a6dacacSDimitry Andric Bundles.push_back(OperandBundleDef(*Bundle));
7557a6dacacSDimitry Andric
7567a6dacacSDimitry Andric // Load the global symbol as a pointer to the check function.
7577a6dacacSDimitry Andric Value *GuardFn;
7587a6dacacSDimitry Andric if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
7597a6dacacSDimitry Andric GuardFn = GuardFnCFGlobal;
7607a6dacacSDimitry Andric else
7617a6dacacSDimitry Andric GuardFn = GuardFnGlobal;
7627a6dacacSDimitry Andric LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
7637a6dacacSDimitry Andric
7647a6dacacSDimitry Andric // Create new call instruction. The CFGuard check should always be a call,
7657a6dacacSDimitry Andric // even if the original CallBase is an Invoke or CallBr instruction.
7667a6dacacSDimitry Andric Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
7677a6dacacSDimitry Andric CallInst *GuardCheck =
7687a6dacacSDimitry Andric B.CreateCall(GuardFnType, GuardCheckLoad,
7697a6dacacSDimitry Andric {B.CreateBitCast(CalledOperand, B.getPtrTy()),
7707a6dacacSDimitry Andric B.CreateBitCast(Thunk, B.getPtrTy())},
7717a6dacacSDimitry Andric Bundles);
7727a6dacacSDimitry Andric
7737a6dacacSDimitry Andric // Ensure that the first argument is passed in the correct register.
7747a6dacacSDimitry Andric GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
7757a6dacacSDimitry Andric
7767a6dacacSDimitry Andric Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
7777a6dacacSDimitry Andric CB->setCalledOperand(GuardRetVal);
7787a6dacacSDimitry Andric }
7797a6dacacSDimitry Andric
runOnModule(Module & Mod)7807a6dacacSDimitry Andric bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
7817a6dacacSDimitry Andric if (!GenerateThunks)
7827a6dacacSDimitry Andric return false;
7837a6dacacSDimitry Andric
7847a6dacacSDimitry Andric M = &Mod;
7857a6dacacSDimitry Andric
7867a6dacacSDimitry Andric // Check if this module has the cfguard flag and read its value.
7877a6dacacSDimitry Andric if (auto *MD =
7887a6dacacSDimitry Andric mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
7897a6dacacSDimitry Andric cfguard_module_flag = MD->getZExtValue();
7907a6dacacSDimitry Andric
7917a6dacacSDimitry Andric PtrTy = PointerType::getUnqual(M->getContext());
7927a6dacacSDimitry Andric I64Ty = Type::getInt64Ty(M->getContext());
7937a6dacacSDimitry Andric VoidTy = Type::getVoidTy(M->getContext());
7947a6dacacSDimitry Andric
7957a6dacacSDimitry Andric GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
7967a6dacacSDimitry Andric GuardFnPtrType = PointerType::get(GuardFnType, 0);
7970fca6ea1SDimitry Andric DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
7980fca6ea1SDimitry Andric DispatchFnPtrType = PointerType::get(DispatchFnType, 0);
7997a6dacacSDimitry Andric GuardFnCFGlobal =
8007a6dacacSDimitry Andric M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
8017a6dacacSDimitry Andric GuardFnGlobal =
8027a6dacacSDimitry Andric M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
8030fca6ea1SDimitry Andric DispatchFnGlobal =
8040fca6ea1SDimitry Andric M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType);
8057a6dacacSDimitry Andric
8060fca6ea1SDimitry Andric DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
8070fca6ea1SDimitry Andric SetVector<GlobalAlias *> PatchableFns;
8080fca6ea1SDimitry Andric
8090fca6ea1SDimitry Andric for (Function &F : Mod) {
8100fca6ea1SDimitry Andric if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() ||
8110fca6ea1SDimitry Andric F.hasLocalLinkage() || F.getName().ends_with("$hp_target"))
8120fca6ea1SDimitry Andric continue;
8130fca6ea1SDimitry Andric
8140fca6ea1SDimitry Andric // Rename hybrid patchable functions and change callers to use a global
8150fca6ea1SDimitry Andric // alias instead.
8160fca6ea1SDimitry Andric if (std::optional<std::string> MangledName =
8170fca6ea1SDimitry Andric getArm64ECMangledFunctionName(F.getName().str())) {
8180fca6ea1SDimitry Andric std::string OrigName(F.getName());
8190fca6ea1SDimitry Andric F.setName(MangledName.value() + "$hp_target");
8200fca6ea1SDimitry Andric
8210fca6ea1SDimitry Andric // The unmangled symbol is a weak alias to an undefined symbol with the
8220fca6ea1SDimitry Andric // "EXP+" prefix. This undefined symbol is resolved by the linker by
8230fca6ea1SDimitry Andric // creating an x86 thunk that jumps back to the actual EC target. Since we
8240fca6ea1SDimitry Andric // can't represent that in IR, we create an alias to the target instead.
8250fca6ea1SDimitry Andric // The "EXP+" symbol is set as metadata, which is then used by
8260fca6ea1SDimitry Andric // emitGlobalAlias to emit the right alias.
8270fca6ea1SDimitry Andric auto *A =
8280fca6ea1SDimitry Andric GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F);
8290fca6ea1SDimitry Andric F.replaceAllUsesWith(A);
8300fca6ea1SDimitry Andric F.setMetadata("arm64ec_exp_name",
8310fca6ea1SDimitry Andric MDNode::get(M->getContext(),
8320fca6ea1SDimitry Andric MDString::get(M->getContext(),
8330fca6ea1SDimitry Andric "EXP+" + MangledName.value())));
8340fca6ea1SDimitry Andric A->setAliasee(&F);
8350fca6ea1SDimitry Andric
836*52418fc2SDimitry Andric if (F.hasDLLExportStorageClass()) {
837*52418fc2SDimitry Andric A->setDLLStorageClass(GlobalValue::DLLExportStorageClass);
838*52418fc2SDimitry Andric F.setDLLStorageClass(GlobalValue::DefaultStorageClass);
839*52418fc2SDimitry Andric }
840*52418fc2SDimitry Andric
8410fca6ea1SDimitry Andric FnsMap[A] = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage,
8420fca6ea1SDimitry Andric MangledName.value(), &F);
8430fca6ea1SDimitry Andric PatchableFns.insert(A);
8440fca6ea1SDimitry Andric }
8450fca6ea1SDimitry Andric }
8460fca6ea1SDimitry Andric
8470fca6ea1SDimitry Andric SetVector<GlobalValue *> DirectCalledFns;
8487a6dacacSDimitry Andric for (Function &F : Mod)
8497a6dacacSDimitry Andric if (!F.isDeclaration() &&
8507a6dacacSDimitry Andric F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
8517a6dacacSDimitry Andric F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
8520fca6ea1SDimitry Andric processFunction(F, DirectCalledFns, FnsMap);
8537a6dacacSDimitry Andric
8547a6dacacSDimitry Andric struct ThunkInfo {
8557a6dacacSDimitry Andric Constant *Src;
8567a6dacacSDimitry Andric Constant *Dst;
8570fca6ea1SDimitry Andric Arm64ECThunkType Kind;
8587a6dacacSDimitry Andric };
8597a6dacacSDimitry Andric SmallVector<ThunkInfo> ThunkMapping;
8607a6dacacSDimitry Andric for (Function &F : Mod) {
8617a6dacacSDimitry Andric if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
8627a6dacacSDimitry Andric F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
8637a6dacacSDimitry Andric F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
8647a6dacacSDimitry Andric if (!F.hasComdat())
8657a6dacacSDimitry Andric F.setComdat(Mod.getOrInsertComdat(F.getName()));
8667a6dacacSDimitry Andric ThunkMapping.push_back(
8670fca6ea1SDimitry Andric {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
8680fca6ea1SDimitry Andric }
8690fca6ea1SDimitry Andric }
8700fca6ea1SDimitry Andric for (GlobalValue *O : DirectCalledFns) {
8710fca6ea1SDimitry Andric auto GA = dyn_cast<GlobalAlias>(O);
8720fca6ea1SDimitry Andric auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O);
8730fca6ea1SDimitry Andric ThunkMapping.push_back(
8740fca6ea1SDimitry Andric {O, buildExitThunk(F->getFunctionType(), F->getAttributes()),
8750fca6ea1SDimitry Andric Arm64ECThunkType::Exit});
8760fca6ea1SDimitry Andric if (!GA && !F->hasDLLImportStorageClass())
8770fca6ea1SDimitry Andric ThunkMapping.push_back(
8780fca6ea1SDimitry Andric {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
8790fca6ea1SDimitry Andric }
8800fca6ea1SDimitry Andric for (GlobalAlias *A : PatchableFns) {
8810fca6ea1SDimitry Andric Function *Thunk = buildPatchableThunk(A, FnsMap[A]);
8820fca6ea1SDimitry Andric ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit});
8837a6dacacSDimitry Andric }
8847a6dacacSDimitry Andric
8857a6dacacSDimitry Andric if (!ThunkMapping.empty()) {
8867a6dacacSDimitry Andric SmallVector<Constant *> ThunkMappingArrayElems;
8877a6dacacSDimitry Andric for (ThunkInfo &Thunk : ThunkMapping) {
8887a6dacacSDimitry Andric ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
8897a6dacacSDimitry Andric {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
8907a6dacacSDimitry Andric ConstantExpr::getBitCast(Thunk.Dst, PtrTy),
8910fca6ea1SDimitry Andric ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))}));
8927a6dacacSDimitry Andric }
8937a6dacacSDimitry Andric Constant *ThunkMappingArray = ConstantArray::get(
8947a6dacacSDimitry Andric llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
8957a6dacacSDimitry Andric ThunkMappingArrayElems.size()),
8967a6dacacSDimitry Andric ThunkMappingArrayElems);
8977a6dacacSDimitry Andric new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
8987a6dacacSDimitry Andric GlobalValue::ExternalLinkage, ThunkMappingArray,
8997a6dacacSDimitry Andric "llvm.arm64ec.symbolmap");
9007a6dacacSDimitry Andric }
9017a6dacacSDimitry Andric
9027a6dacacSDimitry Andric return true;
9037a6dacacSDimitry Andric }
9047a6dacacSDimitry Andric
processFunction(Function & F,SetVector<GlobalValue * > & DirectCalledFns,DenseMap<GlobalAlias *,GlobalAlias * > & FnsMap)9057a6dacacSDimitry Andric bool AArch64Arm64ECCallLowering::processFunction(
9060fca6ea1SDimitry Andric Function &F, SetVector<GlobalValue *> &DirectCalledFns,
9070fca6ea1SDimitry Andric DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
9087a6dacacSDimitry Andric SmallVector<CallBase *, 8> IndirectCalls;
9097a6dacacSDimitry Andric
9107a6dacacSDimitry Andric // For ARM64EC targets, a function definition's name is mangled differently
9117a6dacacSDimitry Andric // from the normal symbol. We currently have no representation of this sort
9127a6dacacSDimitry Andric // of symbol in IR, so we change the name to the mangled name, then store
9137a6dacacSDimitry Andric // the unmangled name as metadata. Later passes that need the unmangled
9147a6dacacSDimitry Andric // name (emitting the definition) can grab it from the metadata.
9157a6dacacSDimitry Andric //
9167a6dacacSDimitry Andric // FIXME: Handle functions with weak linkage?
9170fca6ea1SDimitry Andric if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
9187a6dacacSDimitry Andric if (std::optional<std::string> MangledName =
9197a6dacacSDimitry Andric getArm64ECMangledFunctionName(F.getName().str())) {
9207a6dacacSDimitry Andric F.setMetadata("arm64ec_unmangled_name",
9217a6dacacSDimitry Andric MDNode::get(M->getContext(),
9227a6dacacSDimitry Andric MDString::get(M->getContext(), F.getName())));
9237a6dacacSDimitry Andric if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
9247a6dacacSDimitry Andric Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
9257a6dacacSDimitry Andric SmallVector<GlobalObject *> ComdatUsers =
9267a6dacacSDimitry Andric to_vector(F.getComdat()->getUsers());
9277a6dacacSDimitry Andric for (GlobalObject *User : ComdatUsers)
9287a6dacacSDimitry Andric User->setComdat(MangledComdat);
9297a6dacacSDimitry Andric }
9307a6dacacSDimitry Andric F.setName(MangledName.value());
9317a6dacacSDimitry Andric }
9327a6dacacSDimitry Andric }
9337a6dacacSDimitry Andric
9347a6dacacSDimitry Andric // Iterate over the instructions to find all indirect call/invoke/callbr
9357a6dacacSDimitry Andric // instructions. Make a separate list of pointers to indirect
9367a6dacacSDimitry Andric // call/invoke/callbr instructions because the original instructions will be
9377a6dacacSDimitry Andric // deleted as the checks are added.
9387a6dacacSDimitry Andric for (BasicBlock &BB : F) {
9397a6dacacSDimitry Andric for (Instruction &I : BB) {
9407a6dacacSDimitry Andric auto *CB = dyn_cast<CallBase>(&I);
9417a6dacacSDimitry Andric if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
9427a6dacacSDimitry Andric CB->isInlineAsm())
9437a6dacacSDimitry Andric continue;
9447a6dacacSDimitry Andric
9457a6dacacSDimitry Andric // We need to instrument any call that isn't directly calling an
9467a6dacacSDimitry Andric // ARM64 function.
9477a6dacacSDimitry Andric //
9487a6dacacSDimitry Andric // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
9497a6dacacSDimitry Andric // unprototyped functions in C)
9507a6dacacSDimitry Andric if (Function *F = CB->getCalledFunction()) {
9517a6dacacSDimitry Andric if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
9527a6dacacSDimitry Andric F->isIntrinsic() || !F->isDeclaration())
9537a6dacacSDimitry Andric continue;
9547a6dacacSDimitry Andric
9557a6dacacSDimitry Andric DirectCalledFns.insert(F);
9567a6dacacSDimitry Andric continue;
9577a6dacacSDimitry Andric }
9587a6dacacSDimitry Andric
9590fca6ea1SDimitry Andric // Use mangled global alias for direct calls to patchable functions.
9600fca6ea1SDimitry Andric if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) {
9610fca6ea1SDimitry Andric auto I = FnsMap.find(A);
9620fca6ea1SDimitry Andric if (I != FnsMap.end()) {
9630fca6ea1SDimitry Andric CB->setCalledOperand(I->second);
9640fca6ea1SDimitry Andric DirectCalledFns.insert(I->first);
9650fca6ea1SDimitry Andric continue;
9660fca6ea1SDimitry Andric }
9670fca6ea1SDimitry Andric }
9680fca6ea1SDimitry Andric
9697a6dacacSDimitry Andric IndirectCalls.push_back(CB);
9707a6dacacSDimitry Andric ++Arm64ECCallsLowered;
9717a6dacacSDimitry Andric }
9727a6dacacSDimitry Andric }
9737a6dacacSDimitry Andric
9747a6dacacSDimitry Andric if (IndirectCalls.empty())
9757a6dacacSDimitry Andric return false;
9767a6dacacSDimitry Andric
9777a6dacacSDimitry Andric for (CallBase *CB : IndirectCalls)
9787a6dacacSDimitry Andric lowerCall(CB);
9797a6dacacSDimitry Andric
9807a6dacacSDimitry Andric return true;
9817a6dacacSDimitry Andric }
9827a6dacacSDimitry Andric
9837a6dacacSDimitry Andric char AArch64Arm64ECCallLowering::ID = 0;
9847a6dacacSDimitry Andric INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
9857a6dacacSDimitry Andric "AArch64Arm64ECCallLowering", false, false)
9867a6dacacSDimitry Andric
createAArch64Arm64ECCallLoweringPass()9877a6dacacSDimitry Andric ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
9887a6dacacSDimitry Andric return new AArch64Arm64ECCallLowering;
9897a6dacacSDimitry Andric }
990