1 //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file contains the IR transform to lower external or indirect calls for 11 /// the ARM64EC calling convention. Such calls must go through the runtime, so 12 /// we can translate the calling convention for calls into the emulator. 13 /// 14 /// This subsumes Control Flow Guard handling. 15 /// 16 //===----------------------------------------------------------------------===// 17 18 #include "AArch64.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/Statistic.h" 23 #include "llvm/IR/CallingConv.h" 24 #include "llvm/IR/GlobalAlias.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/Instruction.h" 27 #include "llvm/IR/Mangler.h" 28 #include "llvm/IR/Module.h" 29 #include "llvm/Object/COFF.h" 30 #include "llvm/Pass.h" 31 #include "llvm/Support/CommandLine.h" 32 #include "llvm/TargetParser/Triple.h" 33 34 using namespace llvm; 35 using namespace llvm::COFF; 36 37 using OperandBundleDef = OperandBundleDefT<Value *>; 38 39 #define DEBUG_TYPE "arm64eccalllowering" 40 41 STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered"); 42 43 static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect", 44 cl::Hidden, cl::init(true)); 45 static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden, 46 cl::init(true)); 47 48 namespace { 49 50 enum ThunkArgTranslation : uint8_t { 51 Direct, 52 Bitcast, 53 PointerIndirection, 54 }; 55 56 struct ThunkArgInfo { 57 Type *Arm64Ty; 58 Type *X64Ty; 59 ThunkArgTranslation Translation; 60 }; 61 62 class AArch64Arm64ECCallLowering : public ModulePass { 63 public: 64 static char ID; 65 AArch64Arm64ECCallLowering() : ModulePass(ID) {} 66 67 Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs); 68 Function *buildEntryThunk(Function *F); 69 void lowerCall(CallBase *CB); 70 Function *buildGuestExitThunk(Function *F); 71 Function *buildPatchableThunk(GlobalAlias *UnmangledAlias, 72 GlobalAlias *MangledAlias); 73 bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns, 74 DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap); 75 bool runOnModule(Module &M) override; 76 77 private: 78 int cfguard_module_flag = 0; 79 FunctionType *GuardFnType = nullptr; 80 FunctionType *DispatchFnType = nullptr; 81 Constant *GuardFnCFGlobal = nullptr; 82 Constant *GuardFnGlobal = nullptr; 83 Constant *DispatchFnGlobal = nullptr; 84 Module *M = nullptr; 85 86 Type *PtrTy; 87 Type *I64Ty; 88 Type *VoidTy; 89 90 void getThunkType(FunctionType *FT, AttributeList AttrList, 91 Arm64ECThunkType TT, raw_ostream &Out, 92 FunctionType *&Arm64Ty, FunctionType *&X64Ty, 93 SmallVector<ThunkArgTranslation> &ArgTranslations); 94 void getThunkRetType(FunctionType *FT, AttributeList AttrList, 95 raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy, 96 SmallVectorImpl<Type *> &Arm64ArgTypes, 97 SmallVectorImpl<Type *> &X64ArgTypes, 98 SmallVector<ThunkArgTranslation> &ArgTranslations, 99 bool &HasSretPtr); 100 void getThunkArgTypes(FunctionType *FT, AttributeList AttrList, 101 Arm64ECThunkType TT, raw_ostream &Out, 102 SmallVectorImpl<Type *> &Arm64ArgTypes, 103 SmallVectorImpl<Type *> &X64ArgTypes, 104 SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, 105 bool HasSretPtr); 106 ThunkArgInfo canonicalizeThunkType(Type *T, Align Alignment, bool Ret, 107 uint64_t ArgSizeBytes, raw_ostream &Out); 108 }; 109 110 } // end anonymous namespace 111 112 void AArch64Arm64ECCallLowering::getThunkType( 113 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT, 114 raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty, 115 SmallVector<ThunkArgTranslation> &ArgTranslations) { 116 Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$" 117 : "$iexit_thunk$cdecl$"); 118 119 Type *Arm64RetTy; 120 Type *X64RetTy; 121 122 SmallVector<Type *> Arm64ArgTypes; 123 SmallVector<Type *> X64ArgTypes; 124 125 // The first argument to a thunk is the called function, stored in x9. 126 // For exit thunks, we pass the called function down to the emulator; 127 // for entry/guest exit thunks, we just call the Arm64 function directly. 128 if (TT == Arm64ECThunkType::Exit) 129 Arm64ArgTypes.push_back(PtrTy); 130 X64ArgTypes.push_back(PtrTy); 131 132 bool HasSretPtr = false; 133 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes, 134 X64ArgTypes, ArgTranslations, HasSretPtr); 135 136 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes, 137 ArgTranslations, HasSretPtr); 138 139 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false); 140 141 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false); 142 } 143 144 void AArch64Arm64ECCallLowering::getThunkArgTypes( 145 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT, 146 raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes, 147 SmallVectorImpl<Type *> &X64ArgTypes, 148 SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) { 149 150 Out << "$"; 151 if (FT->isVarArg()) { 152 // We treat the variadic function's thunk as a normal function 153 // with the following type on the ARM side: 154 // rettype exitthunk( 155 // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5) 156 // 157 // that can coverage all types of variadic function. 158 // x9 is similar to normal exit thunk, store the called function. 159 // x0-x3 is the arguments be stored in registers. 160 // x4 is the address of the arguments on the stack. 161 // x5 is the size of the arguments on the stack. 162 // 163 // On the x64 side, it's the same except that x5 isn't set. 164 // 165 // If both the ARM and X64 sides are sret, there are only three 166 // arguments in registers. 167 // 168 // If the X64 side is sret, but the ARM side isn't, we pass an extra value 169 // to/from the X64 side, and let SelectionDAG transform it into a memory 170 // location. 171 Out << "varargs"; 172 173 // x0-x3 174 for (int i = HasSretPtr ? 1 : 0; i < 4; i++) { 175 Arm64ArgTypes.push_back(I64Ty); 176 X64ArgTypes.push_back(I64Ty); 177 ArgTranslations.push_back(ThunkArgTranslation::Direct); 178 } 179 180 // x4 181 Arm64ArgTypes.push_back(PtrTy); 182 X64ArgTypes.push_back(PtrTy); 183 ArgTranslations.push_back(ThunkArgTranslation::Direct); 184 // x5 185 Arm64ArgTypes.push_back(I64Ty); 186 if (TT != Arm64ECThunkType::Entry) { 187 // FIXME: x5 isn't actually used by the x64 side; revisit once we 188 // have proper isel for varargs 189 X64ArgTypes.push_back(I64Ty); 190 ArgTranslations.push_back(ThunkArgTranslation::Direct); 191 } 192 return; 193 } 194 195 unsigned I = 0; 196 if (HasSretPtr) 197 I++; 198 199 if (I == FT->getNumParams()) { 200 Out << "v"; 201 return; 202 } 203 204 for (unsigned E = FT->getNumParams(); I != E; ++I) { 205 #if 0 206 // FIXME: Need more information about argument size; see 207 // https://reviews.llvm.org/D132926 208 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I); 209 Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne(); 210 #else 211 uint64_t ArgSizeBytes = 0; 212 Align ParamAlign = Align(); 213 #endif 214 auto [Arm64Ty, X64Ty, ArgTranslation] = 215 canonicalizeThunkType(FT->getParamType(I), ParamAlign, 216 /*Ret*/ false, ArgSizeBytes, Out); 217 Arm64ArgTypes.push_back(Arm64Ty); 218 X64ArgTypes.push_back(X64Ty); 219 ArgTranslations.push_back(ArgTranslation); 220 } 221 } 222 223 void AArch64Arm64ECCallLowering::getThunkRetType( 224 FunctionType *FT, AttributeList AttrList, raw_ostream &Out, 225 Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes, 226 SmallVectorImpl<Type *> &X64ArgTypes, 227 SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) { 228 Type *T = FT->getReturnType(); 229 #if 0 230 // FIXME: Need more information about argument size; see 231 // https://reviews.llvm.org/D132926 232 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes(); 233 #else 234 int64_t ArgSizeBytes = 0; 235 #endif 236 if (T->isVoidTy()) { 237 if (FT->getNumParams()) { 238 Attribute SRetAttr0 = AttrList.getParamAttr(0, Attribute::StructRet); 239 Attribute InRegAttr0 = AttrList.getParamAttr(0, Attribute::InReg); 240 Attribute SRetAttr1, InRegAttr1; 241 if (FT->getNumParams() > 1) { 242 // Also check the second parameter (for class methods, the first 243 // parameter is "this", and the second parameter is the sret pointer.) 244 // It doesn't matter which one is sret. 245 SRetAttr1 = AttrList.getParamAttr(1, Attribute::StructRet); 246 InRegAttr1 = AttrList.getParamAttr(1, Attribute::InReg); 247 } 248 if ((SRetAttr0.isValid() && InRegAttr0.isValid()) || 249 (SRetAttr1.isValid() && InRegAttr1.isValid())) { 250 // sret+inreg indicates a call that returns a C++ class value. This is 251 // actually equivalent to just passing and returning a void* pointer 252 // as the first or second argument. Translate it that way, instead of 253 // trying to model "inreg" in the thunk's calling convention; this 254 // simplfies the rest of the code, and matches MSVC mangling. 255 Out << "i8"; 256 Arm64RetTy = I64Ty; 257 X64RetTy = I64Ty; 258 return; 259 } 260 if (SRetAttr0.isValid()) { 261 // FIXME: Sanity-check the sret type; if it's an integer or pointer, 262 // we'll get screwy mangling/codegen. 263 // FIXME: For large struct types, mangle as an integer argument and 264 // integer return, so we can reuse more thunks, instead of "m" syntax. 265 // (MSVC mangles this case as an integer return with no argument, but 266 // that's a miscompile.) 267 Type *SRetType = SRetAttr0.getValueAsType(); 268 Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne(); 269 canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes, 270 Out); 271 Arm64RetTy = VoidTy; 272 X64RetTy = VoidTy; 273 Arm64ArgTypes.push_back(FT->getParamType(0)); 274 X64ArgTypes.push_back(FT->getParamType(0)); 275 ArgTranslations.push_back(ThunkArgTranslation::Direct); 276 HasSretPtr = true; 277 return; 278 } 279 } 280 281 Out << "v"; 282 Arm64RetTy = VoidTy; 283 X64RetTy = VoidTy; 284 return; 285 } 286 287 auto info = 288 canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out); 289 Arm64RetTy = info.Arm64Ty; 290 X64RetTy = info.X64Ty; 291 if (X64RetTy->isPointerTy()) { 292 // If the X64 type is canonicalized to a pointer, that means it's 293 // passed/returned indirectly. For a return value, that means it's an 294 // sret pointer. 295 X64ArgTypes.push_back(X64RetTy); 296 X64RetTy = VoidTy; 297 } 298 } 299 300 ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType( 301 Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, 302 raw_ostream &Out) { 303 304 auto direct = [](Type *T) { 305 return ThunkArgInfo{T, T, ThunkArgTranslation::Direct}; 306 }; 307 308 auto bitcast = [this](Type *Arm64Ty, uint64_t SizeInBytes) { 309 return ThunkArgInfo{Arm64Ty, 310 llvm::Type::getIntNTy(M->getContext(), SizeInBytes * 8), 311 ThunkArgTranslation::Bitcast}; 312 }; 313 314 auto pointerIndirection = [this](Type *Arm64Ty) { 315 return ThunkArgInfo{Arm64Ty, PtrTy, 316 ThunkArgTranslation::PointerIndirection}; 317 }; 318 319 if (T->isFloatTy()) { 320 Out << "f"; 321 return direct(T); 322 } 323 324 if (T->isDoubleTy()) { 325 Out << "d"; 326 return direct(T); 327 } 328 329 if (T->isFloatingPointTy()) { 330 report_fatal_error( 331 "Only 32 and 64 bit floating points are supported for ARM64EC thunks"); 332 } 333 334 auto &DL = M->getDataLayout(); 335 336 if (auto *StructTy = dyn_cast<StructType>(T)) 337 if (StructTy->getNumElements() == 1) 338 T = StructTy->getElementType(0); 339 340 if (T->isArrayTy()) { 341 Type *ElementTy = T->getArrayElementType(); 342 uint64_t ElementCnt = T->getArrayNumElements(); 343 uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8; 344 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes; 345 if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) { 346 Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes; 347 if (Alignment.value() >= 16 && !Ret) 348 Out << "a" << Alignment.value(); 349 if (TotalSizeBytes <= 8) { 350 // Arm64 returns small structs of float/double in float registers; 351 // X64 uses RAX. 352 return bitcast(T, TotalSizeBytes); 353 } else { 354 // Struct is passed directly on Arm64, but indirectly on X64. 355 return pointerIndirection(T); 356 } 357 } else if (T->isFloatingPointTy()) { 358 report_fatal_error("Only 32 and 64 bit floating points are supported for " 359 "ARM64EC thunks"); 360 } 361 } 362 363 if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) { 364 Out << "i8"; 365 return direct(I64Ty); 366 } 367 368 unsigned TypeSize = ArgSizeBytes; 369 if (TypeSize == 0) 370 TypeSize = DL.getTypeSizeInBits(T) / 8; 371 Out << "m"; 372 if (TypeSize != 4) 373 Out << TypeSize; 374 if (Alignment.value() >= 16 && !Ret) 375 Out << "a" << Alignment.value(); 376 // FIXME: Try to canonicalize Arm64Ty more thoroughly? 377 if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) { 378 // Pass directly in an integer register 379 return bitcast(T, TypeSize); 380 } else { 381 // Passed directly on Arm64, but indirectly on X64. 382 return pointerIndirection(T); 383 } 384 } 385 386 // This function builds the "exit thunk", a function which translates 387 // arguments and return values when calling x64 code from AArch64 code. 388 Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT, 389 AttributeList Attrs) { 390 SmallString<256> ExitThunkName; 391 llvm::raw_svector_ostream ExitThunkStream(ExitThunkName); 392 FunctionType *Arm64Ty, *X64Ty; 393 SmallVector<ThunkArgTranslation> ArgTranslations; 394 getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty, 395 X64Ty, ArgTranslations); 396 if (Function *F = M->getFunction(ExitThunkName)) 397 return F; 398 399 Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0, 400 ExitThunkName, M); 401 F->setCallingConv(CallingConv::ARM64EC_Thunk_Native); 402 F->setSection(".wowthk$aa"); 403 F->setComdat(M->getOrInsertComdat(ExitThunkName)); 404 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.) 405 F->addFnAttr("frame-pointer", "all"); 406 // Only copy sret from the first argument. For C++ instance methods, clang can 407 // stick an sret marking on a later argument, but it doesn't actually affect 408 // the ABI, so we can omit it. This avoids triggering a verifier assertion. 409 if (FT->getNumParams()) { 410 auto SRet = Attrs.getParamAttr(0, Attribute::StructRet); 411 auto InReg = Attrs.getParamAttr(0, Attribute::InReg); 412 if (SRet.isValid() && !InReg.isValid()) 413 F->addParamAttr(1, SRet); 414 } 415 // FIXME: Copy anything other than sret? Shouldn't be necessary for normal 416 // C ABI, but might show up in other cases. 417 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F); 418 IRBuilder<> IRB(BB); 419 Value *CalleePtr = 420 M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy); 421 Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr); 422 auto &DL = M->getDataLayout(); 423 SmallVector<Value *> Args; 424 425 // Pass the called function in x9. 426 auto X64TyOffset = 1; 427 Args.push_back(F->arg_begin()); 428 429 Type *RetTy = Arm64Ty->getReturnType(); 430 if (RetTy != X64Ty->getReturnType()) { 431 // If the return type is an array or struct, translate it. Values of size 432 // 8 or less go into RAX; bigger values go into memory, and we pass a 433 // pointer. 434 if (DL.getTypeStoreSize(RetTy) > 8) { 435 Args.push_back(IRB.CreateAlloca(RetTy)); 436 X64TyOffset++; 437 } 438 } 439 440 for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal( 441 make_range(F->arg_begin() + 1, F->arg_end()), 442 make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()), 443 ArgTranslations)) { 444 // Translate arguments from AArch64 calling convention to x86 calling 445 // convention. 446 // 447 // For simple types, we don't need to do any translation: they're 448 // represented the same way. (Implicit sign extension is not part of 449 // either convention.) 450 // 451 // The big thing we have to worry about is struct types... but 452 // fortunately AArch64 clang is pretty friendly here: the cases that need 453 // translation are always passed as a struct or array. (If we run into 454 // some cases where this doesn't work, we can teach clang to mark it up 455 // with an attribute.) 456 // 457 // The first argument is the called function, stored in x9. 458 if (ArgTranslation != ThunkArgTranslation::Direct) { 459 Value *Mem = IRB.CreateAlloca(Arg.getType()); 460 IRB.CreateStore(&Arg, Mem); 461 if (ArgTranslation == ThunkArgTranslation::Bitcast) { 462 Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType())); 463 Args.push_back(IRB.CreateLoad(IntTy, Mem)); 464 } else { 465 assert(ArgTranslation == ThunkArgTranslation::PointerIndirection); 466 Args.push_back(Mem); 467 } 468 } else { 469 Args.push_back(&Arg); 470 } 471 assert(Args.back()->getType() == X64ArgType); 472 } 473 // FIXME: Transfer necessary attributes? sret? anything else? 474 475 CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args); 476 Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64); 477 478 Value *RetVal = Call; 479 if (RetTy != X64Ty->getReturnType()) { 480 // If we rewrote the return type earlier, convert the return value to 481 // the proper type. 482 if (DL.getTypeStoreSize(RetTy) > 8) { 483 RetVal = IRB.CreateLoad(RetTy, Args[1]); 484 } else { 485 Value *CastAlloca = IRB.CreateAlloca(RetTy); 486 IRB.CreateStore(Call, CastAlloca); 487 RetVal = IRB.CreateLoad(RetTy, CastAlloca); 488 } 489 } 490 491 if (RetTy->isVoidTy()) 492 IRB.CreateRetVoid(); 493 else 494 IRB.CreateRet(RetVal); 495 return F; 496 } 497 498 // This function builds the "entry thunk", a function which translates 499 // arguments and return values when calling AArch64 code from x64 code. 500 Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) { 501 SmallString<256> EntryThunkName; 502 llvm::raw_svector_ostream EntryThunkStream(EntryThunkName); 503 FunctionType *Arm64Ty, *X64Ty; 504 SmallVector<ThunkArgTranslation> ArgTranslations; 505 getThunkType(F->getFunctionType(), F->getAttributes(), 506 Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty, 507 ArgTranslations); 508 if (Function *F = M->getFunction(EntryThunkName)) 509 return F; 510 511 Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0, 512 EntryThunkName, M); 513 Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64); 514 Thunk->setSection(".wowthk$aa"); 515 Thunk->setComdat(M->getOrInsertComdat(EntryThunkName)); 516 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.) 517 Thunk->addFnAttr("frame-pointer", "all"); 518 519 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk); 520 IRBuilder<> IRB(BB); 521 522 Type *RetTy = Arm64Ty->getReturnType(); 523 Type *X64RetType = X64Ty->getReturnType(); 524 525 bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy(); 526 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1; 527 unsigned PassthroughArgSize = 528 (F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset; 529 assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize)); 530 531 // Translate arguments to call. 532 SmallVector<Value *> Args; 533 for (unsigned i = 0; i != PassthroughArgSize; ++i) { 534 Value *Arg = Thunk->getArg(i + ThunkArgOffset); 535 Type *ArgTy = Arm64Ty->getParamType(i); 536 ThunkArgTranslation ArgTranslation = ArgTranslations[i]; 537 if (ArgTranslation != ThunkArgTranslation::Direct) { 538 // Translate array/struct arguments to the expected type. 539 if (ArgTranslation == ThunkArgTranslation::Bitcast) { 540 Value *CastAlloca = IRB.CreateAlloca(ArgTy); 541 IRB.CreateStore(Arg, CastAlloca); 542 Arg = IRB.CreateLoad(ArgTy, CastAlloca); 543 } else { 544 assert(ArgTranslation == ThunkArgTranslation::PointerIndirection); 545 Arg = IRB.CreateLoad(ArgTy, Arg); 546 } 547 } 548 assert(Arg->getType() == ArgTy); 549 Args.push_back(Arg); 550 } 551 552 if (F->isVarArg()) { 553 // The 5th argument to variadic entry thunks is used to model the x64 sp 554 // which is passed to the thunk in x4, this can be passed to the callee as 555 // the variadic argument start address after skipping over the 32 byte 556 // shadow store. 557 558 // The EC thunk CC will assign any argument marked as InReg to x4. 559 Thunk->addParamAttr(5, Attribute::InReg); 560 Value *Arg = Thunk->getArg(5); 561 Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20)); 562 Args.push_back(Arg); 563 564 // Pass in a zero variadic argument size (in x5). 565 Args.push_back(IRB.getInt64(0)); 566 } 567 568 // Call the function passed to the thunk. 569 Value *Callee = Thunk->getArg(0); 570 CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args); 571 572 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet); 573 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg); 574 if (SRetAttr.isValid() && !InRegAttr.isValid()) { 575 Thunk->addParamAttr(1, SRetAttr); 576 Call->addParamAttr(0, SRetAttr); 577 } 578 579 Value *RetVal = Call; 580 if (TransformDirectToSRet) { 581 IRB.CreateStore(RetVal, Thunk->getArg(1)); 582 } else if (X64RetType != RetTy) { 583 Value *CastAlloca = IRB.CreateAlloca(X64RetType); 584 IRB.CreateStore(Call, CastAlloca); 585 RetVal = IRB.CreateLoad(X64RetType, CastAlloca); 586 } 587 588 // Return to the caller. Note that the isel has code to translate this 589 // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we 590 // could emit a tail call here, but that would require a dedicated calling 591 // convention, which seems more complicated overall.) 592 if (X64RetType->isVoidTy()) 593 IRB.CreateRetVoid(); 594 else 595 IRB.CreateRet(RetVal); 596 597 return Thunk; 598 } 599 600 // Builds the "guest exit thunk", a helper to call a function which may or may 601 // not be an exit thunk. (We optimistically assume non-dllimport function 602 // declarations refer to functions defined in AArch64 code; if the linker 603 // can't prove that, we use this routine instead.) 604 Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) { 605 llvm::raw_null_ostream NullThunkName; 606 FunctionType *Arm64Ty, *X64Ty; 607 SmallVector<ThunkArgTranslation> ArgTranslations; 608 getThunkType(F->getFunctionType(), F->getAttributes(), 609 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty, 610 ArgTranslations); 611 auto MangledName = getArm64ECMangledFunctionName(F->getName().str()); 612 assert(MangledName && "Can't guest exit to function that's already native"); 613 std::string ThunkName = *MangledName; 614 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) { 615 ThunkName.insert(ThunkName.find("@"), "$exit_thunk"); 616 } else { 617 ThunkName.append("$exit_thunk"); 618 } 619 Function *GuestExit = 620 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M); 621 GuestExit->setComdat(M->getOrInsertComdat(ThunkName)); 622 GuestExit->setSection(".wowthk$aa"); 623 GuestExit->addMetadata( 624 "arm64ec_unmangled_name", 625 *MDNode::get(M->getContext(), 626 MDString::get(M->getContext(), F->getName()))); 627 GuestExit->setMetadata( 628 "arm64ec_ecmangled_name", 629 MDNode::get(M->getContext(), 630 MDString::get(M->getContext(), *MangledName))); 631 F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {})); 632 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit); 633 IRBuilder<> B(BB); 634 635 // Load the global symbol as a pointer to the check function. 636 Value *GuardFn; 637 if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf")) 638 GuardFn = GuardFnCFGlobal; 639 else 640 GuardFn = GuardFnGlobal; 641 LoadInst *GuardCheckLoad = B.CreateLoad(PtrTy, GuardFn); 642 643 // Create new call instruction. The CFGuard check should always be a call, 644 // even if the original CallBase is an Invoke or CallBr instruction. 645 Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes()); 646 CallInst *GuardCheck = B.CreateCall( 647 GuardFnType, GuardCheckLoad, {F, Thunk}); 648 649 // Ensure that the first argument is passed in the correct register. 650 GuardCheck->setCallingConv(CallingConv::CFGuard_Check); 651 652 SmallVector<Value *> Args(llvm::make_pointer_range(GuestExit->args())); 653 CallInst *Call = B.CreateCall(Arm64Ty, GuardCheck, Args); 654 Call->setTailCallKind(llvm::CallInst::TCK_MustTail); 655 656 if (Call->getType()->isVoidTy()) 657 B.CreateRetVoid(); 658 else 659 B.CreateRet(Call); 660 661 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet); 662 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg); 663 if (SRetAttr.isValid() && !InRegAttr.isValid()) { 664 GuestExit->addParamAttr(0, SRetAttr); 665 Call->addParamAttr(0, SRetAttr); 666 } 667 668 return GuestExit; 669 } 670 671 Function * 672 AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias, 673 GlobalAlias *MangledAlias) { 674 llvm::raw_null_ostream NullThunkName; 675 FunctionType *Arm64Ty, *X64Ty; 676 Function *F = cast<Function>(MangledAlias->getAliasee()); 677 SmallVector<ThunkArgTranslation> ArgTranslations; 678 getThunkType(F->getFunctionType(), F->getAttributes(), 679 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty, 680 ArgTranslations); 681 std::string ThunkName(MangledAlias->getName()); 682 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) { 683 ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk"); 684 } else { 685 ThunkName.append("$hybpatch_thunk"); 686 } 687 688 Function *GuestExit = 689 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M); 690 GuestExit->setComdat(M->getOrInsertComdat(ThunkName)); 691 GuestExit->setSection(".wowthk$aa"); 692 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit); 693 IRBuilder<> B(BB); 694 695 // Load the global symbol as a pointer to the check function. 696 LoadInst *DispatchLoad = B.CreateLoad(PtrTy, DispatchFnGlobal); 697 698 // Create new dispatch call instruction. 699 Function *ExitThunk = 700 buildExitThunk(F->getFunctionType(), F->getAttributes()); 701 CallInst *Dispatch = 702 B.CreateCall(DispatchFnType, DispatchLoad, 703 {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()}); 704 705 // Ensure that the first arguments are passed in the correct registers. 706 Dispatch->setCallingConv(CallingConv::CFGuard_Check); 707 708 SmallVector<Value *> Args(llvm::make_pointer_range(GuestExit->args())); 709 CallInst *Call = B.CreateCall(Arm64Ty, Dispatch, Args); 710 Call->setTailCallKind(llvm::CallInst::TCK_MustTail); 711 712 if (Call->getType()->isVoidTy()) 713 B.CreateRetVoid(); 714 else 715 B.CreateRet(Call); 716 717 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet); 718 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg); 719 if (SRetAttr.isValid() && !InRegAttr.isValid()) { 720 GuestExit->addParamAttr(0, SRetAttr); 721 Call->addParamAttr(0, SRetAttr); 722 } 723 724 MangledAlias->setAliasee(GuestExit); 725 return GuestExit; 726 } 727 728 // Lower an indirect call with inline code. 729 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) { 730 assert(CB->getModule()->getTargetTriple().isOSWindows() && 731 "Only applicable for Windows targets"); 732 733 IRBuilder<> B(CB); 734 Value *CalledOperand = CB->getCalledOperand(); 735 736 // If the indirect call is called within catchpad or cleanuppad, 737 // we need to copy "funclet" bundle of the call. 738 SmallVector<llvm::OperandBundleDef, 1> Bundles; 739 if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet)) 740 Bundles.push_back(OperandBundleDef(*Bundle)); 741 742 // Load the global symbol as a pointer to the check function. 743 Value *GuardFn; 744 if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf")) 745 GuardFn = GuardFnCFGlobal; 746 else 747 GuardFn = GuardFnGlobal; 748 LoadInst *GuardCheckLoad = B.CreateLoad(PtrTy, GuardFn); 749 750 // Create new call instruction. The CFGuard check should always be a call, 751 // even if the original CallBase is an Invoke or CallBr instruction. 752 Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes()); 753 CallInst *GuardCheck = 754 B.CreateCall(GuardFnType, GuardCheckLoad, {CalledOperand, Thunk}, 755 Bundles); 756 757 // Ensure that the first argument is passed in the correct register. 758 GuardCheck->setCallingConv(CallingConv::CFGuard_Check); 759 760 CB->setCalledOperand(GuardCheck); 761 } 762 763 bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { 764 if (!GenerateThunks) 765 return false; 766 767 M = &Mod; 768 769 // Check if this module has the cfguard flag and read its value. 770 if (auto *MD = 771 mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard"))) 772 cfguard_module_flag = MD->getZExtValue(); 773 774 PtrTy = PointerType::getUnqual(M->getContext()); 775 I64Ty = Type::getInt64Ty(M->getContext()); 776 VoidTy = Type::getVoidTy(M->getContext()); 777 778 GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false); 779 DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false); 780 GuardFnCFGlobal = M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", PtrTy); 781 GuardFnGlobal = M->getOrInsertGlobal("__os_arm64x_check_icall", PtrTy); 782 DispatchFnGlobal = M->getOrInsertGlobal("__os_arm64x_dispatch_call", PtrTy); 783 784 // Mangle names of function aliases and add the alias name to 785 // arm64ec_unmangled_name metadata to ensure a weak anti-dependency symbol is 786 // emitted for the alias as well. Do this early, before handling 787 // hybrid_patchable functions, to avoid mangling their aliases. 788 for (GlobalAlias &A : Mod.aliases()) { 789 auto F = dyn_cast_or_null<Function>(A.getAliaseeObject()); 790 if (!F) 791 continue; 792 if (std::optional<std::string> MangledName = 793 getArm64ECMangledFunctionName(A.getName().str())) { 794 F->addMetadata("arm64ec_unmangled_name", 795 *MDNode::get(M->getContext(), 796 MDString::get(M->getContext(), A.getName()))); 797 A.setName(MangledName.value()); 798 } 799 } 800 801 DenseMap<GlobalAlias *, GlobalAlias *> FnsMap; 802 SetVector<GlobalAlias *> PatchableFns; 803 804 for (Function &F : Mod) { 805 if (F.hasPersonalityFn()) { 806 GlobalValue *PersFn = 807 cast<GlobalValue>(F.getPersonalityFn()->stripPointerCasts()); 808 if (PersFn->getValueType() && PersFn->getValueType()->isFunctionTy()) { 809 if (std::optional<std::string> MangledName = 810 getArm64ECMangledFunctionName(PersFn->getName().str())) { 811 PersFn->setName(MangledName.value()); 812 } 813 } 814 } 815 816 if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() || 817 F.hasLocalLinkage() || 818 F.getName().ends_with(HybridPatchableTargetSuffix)) 819 continue; 820 821 // Rename hybrid patchable functions and change callers to use a global 822 // alias instead. 823 if (std::optional<std::string> MangledName = 824 getArm64ECMangledFunctionName(F.getName().str())) { 825 std::string OrigName(F.getName()); 826 F.setName(MangledName.value() + HybridPatchableTargetSuffix); 827 828 // The unmangled symbol is a weak alias to an undefined symbol with the 829 // "EXP+" prefix. This undefined symbol is resolved by the linker by 830 // creating an x86 thunk that jumps back to the actual EC target. Since we 831 // can't represent that in IR, we create an alias to the target instead. 832 // The "EXP+" symbol is set as metadata, which is then used by 833 // emitGlobalAlias to emit the right alias. 834 auto *A = 835 GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F); 836 auto *AM = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, 837 MangledName.value(), &F); 838 F.replaceUsesWithIf(AM, 839 [](Use &U) { return isa<GlobalAlias>(U.getUser()); }); 840 F.replaceAllUsesWith(A); 841 F.setMetadata("arm64ec_exp_name", 842 MDNode::get(M->getContext(), 843 MDString::get(M->getContext(), 844 "EXP+" + MangledName.value()))); 845 A->setAliasee(&F); 846 AM->setAliasee(&F); 847 848 if (F.hasDLLExportStorageClass()) { 849 A->setDLLStorageClass(GlobalValue::DLLExportStorageClass); 850 F.setDLLStorageClass(GlobalValue::DefaultStorageClass); 851 } 852 853 FnsMap[A] = AM; 854 PatchableFns.insert(A); 855 } 856 } 857 858 SetVector<GlobalValue *> DirectCalledFns; 859 for (Function &F : Mod) 860 if (!F.isDeclaration() && 861 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native && 862 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) 863 processFunction(F, DirectCalledFns, FnsMap); 864 865 struct ThunkInfo { 866 Constant *Src; 867 Constant *Dst; 868 Arm64ECThunkType Kind; 869 }; 870 SmallVector<ThunkInfo> ThunkMapping; 871 for (Function &F : Mod) { 872 if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) && 873 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native && 874 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) { 875 if (!F.hasComdat()) 876 F.setComdat(Mod.getOrInsertComdat(F.getName())); 877 ThunkMapping.push_back( 878 {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry}); 879 } 880 } 881 for (GlobalValue *O : DirectCalledFns) { 882 auto GA = dyn_cast<GlobalAlias>(O); 883 auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O); 884 ThunkMapping.push_back( 885 {O, buildExitThunk(F->getFunctionType(), F->getAttributes()), 886 Arm64ECThunkType::Exit}); 887 if (!GA && !F->hasDLLImportStorageClass()) 888 ThunkMapping.push_back( 889 {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit}); 890 } 891 for (GlobalAlias *A : PatchableFns) { 892 Function *Thunk = buildPatchableThunk(A, FnsMap[A]); 893 ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit}); 894 } 895 896 if (!ThunkMapping.empty()) { 897 SmallVector<Constant *> ThunkMappingArrayElems; 898 for (ThunkInfo &Thunk : ThunkMapping) { 899 ThunkMappingArrayElems.push_back(ConstantStruct::getAnon( 900 {Thunk.Src, Thunk.Dst, 901 ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))})); 902 } 903 Constant *ThunkMappingArray = ConstantArray::get( 904 llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(), 905 ThunkMappingArrayElems.size()), 906 ThunkMappingArrayElems); 907 new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false, 908 GlobalValue::ExternalLinkage, ThunkMappingArray, 909 "llvm.arm64ec.symbolmap"); 910 } 911 912 return true; 913 } 914 915 bool AArch64Arm64ECCallLowering::processFunction( 916 Function &F, SetVector<GlobalValue *> &DirectCalledFns, 917 DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) { 918 SmallVector<CallBase *, 8> IndirectCalls; 919 920 // For ARM64EC targets, a function definition's name is mangled differently 921 // from the normal symbol. We currently have no representation of this sort 922 // of symbol in IR, so we change the name to the mangled name, then store 923 // the unmangled name as metadata. Later passes that need the unmangled 924 // name (emitting the definition) can grab it from the metadata. 925 // 926 // FIXME: Handle functions with weak linkage? 927 if (!F.hasLocalLinkage() || F.hasAddressTaken()) { 928 if (std::optional<std::string> MangledName = 929 getArm64ECMangledFunctionName(F.getName().str())) { 930 F.addMetadata("arm64ec_unmangled_name", 931 *MDNode::get(M->getContext(), 932 MDString::get(M->getContext(), F.getName()))); 933 if (F.hasComdat() && F.getComdat()->getName() == F.getName()) { 934 Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value()); 935 SmallVector<GlobalObject *> ComdatUsers = 936 to_vector(F.getComdat()->getUsers()); 937 for (GlobalObject *User : ComdatUsers) 938 User->setComdat(MangledComdat); 939 } 940 F.setName(MangledName.value()); 941 } 942 } 943 944 // Iterate over the instructions to find all indirect call/invoke/callbr 945 // instructions. Make a separate list of pointers to indirect 946 // call/invoke/callbr instructions because the original instructions will be 947 // deleted as the checks are added. 948 for (BasicBlock &BB : F) { 949 for (Instruction &I : BB) { 950 auto *CB = dyn_cast<CallBase>(&I); 951 if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 || 952 CB->isInlineAsm()) 953 continue; 954 955 // We need to instrument any call that isn't directly calling an 956 // ARM64 function. 957 // 958 // FIXME: getCalledFunction() fails if there's a bitcast (e.g. 959 // unprototyped functions in C) 960 if (Function *F = CB->getCalledFunction()) { 961 if (!LowerDirectToIndirect || F->hasLocalLinkage() || 962 F->isIntrinsic() || !F->isDeclaration()) 963 continue; 964 965 DirectCalledFns.insert(F); 966 continue; 967 } 968 969 // Use mangled global alias for direct calls to patchable functions. 970 if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) { 971 auto I = FnsMap.find(A); 972 if (I != FnsMap.end()) { 973 CB->setCalledOperand(I->second); 974 DirectCalledFns.insert(I->first); 975 continue; 976 } 977 } 978 979 IndirectCalls.push_back(CB); 980 ++Arm64ECCallsLowered; 981 } 982 } 983 984 if (IndirectCalls.empty()) 985 return false; 986 987 for (CallBase *CB : IndirectCalls) 988 lowerCall(CB); 989 990 return true; 991 } 992 993 char AArch64Arm64ECCallLowering::ID = 0; 994 INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering", 995 "AArch64Arm64ECCallLowering", false, false) 996 997 ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() { 998 return new AArch64Arm64ECCallLowering; 999 } 1000