1 //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file contains the IR transform to lower external or indirect calls for 11 /// the ARM64EC calling convention. Such calls must go through the runtime, so 12 /// we can translate the calling convention for calls into the emulator. 13 /// 14 /// This subsumes Control Flow Guard handling. 15 /// 16 //===----------------------------------------------------------------------===// 17 18 #include "AArch64.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/Statistic.h" 23 #include "llvm/IR/CallingConv.h" 24 #include "llvm/IR/GlobalAlias.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/Instruction.h" 27 #include "llvm/IR/Mangler.h" 28 #include "llvm/IR/Module.h" 29 #include "llvm/InitializePasses.h" 30 #include "llvm/Object/COFF.h" 31 #include "llvm/Pass.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/TargetParser/Triple.h" 34 35 using namespace llvm; 36 using namespace llvm::COFF; 37 38 using OperandBundleDef = OperandBundleDefT<Value *>; 39 40 #define DEBUG_TYPE "arm64eccalllowering" 41 42 STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered"); 43 44 static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect", 45 cl::Hidden, cl::init(true)); 46 static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden, 47 cl::init(true)); 48 49 namespace { 50 51 enum ThunkArgTranslation : uint8_t { 52 Direct, 53 Bitcast, 54 PointerIndirection, 55 }; 56 57 struct ThunkArgInfo { 58 Type *Arm64Ty; 59 Type *X64Ty; 60 ThunkArgTranslation Translation; 61 }; 62 63 class AArch64Arm64ECCallLowering : public ModulePass { 64 public: 65 static char ID; 66 AArch64Arm64ECCallLowering() : ModulePass(ID) { 67 initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry()); 68 } 69 70 Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs); 71 Function *buildEntryThunk(Function *F); 72 void lowerCall(CallBase *CB); 73 Function *buildGuestExitThunk(Function *F); 74 Function *buildPatchableThunk(GlobalAlias *UnmangledAlias, 75 GlobalAlias *MangledAlias); 76 bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns, 77 DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap); 78 bool runOnModule(Module &M) override; 79 80 private: 81 int cfguard_module_flag = 0; 82 FunctionType *GuardFnType = nullptr; 83 PointerType *GuardFnPtrType = nullptr; 84 FunctionType *DispatchFnType = nullptr; 85 PointerType *DispatchFnPtrType = nullptr; 86 Constant *GuardFnCFGlobal = nullptr; 87 Constant *GuardFnGlobal = nullptr; 88 Constant *DispatchFnGlobal = nullptr; 89 Module *M = nullptr; 90 91 Type *PtrTy; 92 Type *I64Ty; 93 Type *VoidTy; 94 95 void getThunkType(FunctionType *FT, AttributeList AttrList, 96 Arm64ECThunkType TT, raw_ostream &Out, 97 FunctionType *&Arm64Ty, FunctionType *&X64Ty, 98 SmallVector<ThunkArgTranslation> &ArgTranslations); 99 void getThunkRetType(FunctionType *FT, AttributeList AttrList, 100 raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy, 101 SmallVectorImpl<Type *> &Arm64ArgTypes, 102 SmallVectorImpl<Type *> &X64ArgTypes, 103 SmallVector<ThunkArgTranslation> &ArgTranslations, 104 bool &HasSretPtr); 105 void getThunkArgTypes(FunctionType *FT, AttributeList AttrList, 106 Arm64ECThunkType TT, raw_ostream &Out, 107 SmallVectorImpl<Type *> &Arm64ArgTypes, 108 SmallVectorImpl<Type *> &X64ArgTypes, 109 SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, 110 bool HasSretPtr); 111 ThunkArgInfo canonicalizeThunkType(Type *T, Align Alignment, bool Ret, 112 uint64_t ArgSizeBytes, raw_ostream &Out); 113 }; 114 115 } // end anonymous namespace 116 117 void AArch64Arm64ECCallLowering::getThunkType( 118 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT, 119 raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty, 120 SmallVector<ThunkArgTranslation> &ArgTranslations) { 121 Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$" 122 : "$iexit_thunk$cdecl$"); 123 124 Type *Arm64RetTy; 125 Type *X64RetTy; 126 127 SmallVector<Type *> Arm64ArgTypes; 128 SmallVector<Type *> X64ArgTypes; 129 130 // The first argument to a thunk is the called function, stored in x9. 131 // For exit thunks, we pass the called function down to the emulator; 132 // for entry/guest exit thunks, we just call the Arm64 function directly. 133 if (TT == Arm64ECThunkType::Exit) 134 Arm64ArgTypes.push_back(PtrTy); 135 X64ArgTypes.push_back(PtrTy); 136 137 bool HasSretPtr = false; 138 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes, 139 X64ArgTypes, ArgTranslations, HasSretPtr); 140 141 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes, 142 ArgTranslations, HasSretPtr); 143 144 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false); 145 146 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false); 147 } 148 149 void AArch64Arm64ECCallLowering::getThunkArgTypes( 150 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT, 151 raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes, 152 SmallVectorImpl<Type *> &X64ArgTypes, 153 SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) { 154 155 Out << "$"; 156 if (FT->isVarArg()) { 157 // We treat the variadic function's thunk as a normal function 158 // with the following type on the ARM side: 159 // rettype exitthunk( 160 // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5) 161 // 162 // that can coverage all types of variadic function. 163 // x9 is similar to normal exit thunk, store the called function. 164 // x0-x3 is the arguments be stored in registers. 165 // x4 is the address of the arguments on the stack. 166 // x5 is the size of the arguments on the stack. 167 // 168 // On the x64 side, it's the same except that x5 isn't set. 169 // 170 // If both the ARM and X64 sides are sret, there are only three 171 // arguments in registers. 172 // 173 // If the X64 side is sret, but the ARM side isn't, we pass an extra value 174 // to/from the X64 side, and let SelectionDAG transform it into a memory 175 // location. 176 Out << "varargs"; 177 178 // x0-x3 179 for (int i = HasSretPtr ? 1 : 0; i < 4; i++) { 180 Arm64ArgTypes.push_back(I64Ty); 181 X64ArgTypes.push_back(I64Ty); 182 ArgTranslations.push_back(ThunkArgTranslation::Direct); 183 } 184 185 // x4 186 Arm64ArgTypes.push_back(PtrTy); 187 X64ArgTypes.push_back(PtrTy); 188 ArgTranslations.push_back(ThunkArgTranslation::Direct); 189 // x5 190 Arm64ArgTypes.push_back(I64Ty); 191 if (TT != Arm64ECThunkType::Entry) { 192 // FIXME: x5 isn't actually used by the x64 side; revisit once we 193 // have proper isel for varargs 194 X64ArgTypes.push_back(I64Ty); 195 ArgTranslations.push_back(ThunkArgTranslation::Direct); 196 } 197 return; 198 } 199 200 unsigned I = 0; 201 if (HasSretPtr) 202 I++; 203 204 if (I == FT->getNumParams()) { 205 Out << "v"; 206 return; 207 } 208 209 for (unsigned E = FT->getNumParams(); I != E; ++I) { 210 #if 0 211 // FIXME: Need more information about argument size; see 212 // https://reviews.llvm.org/D132926 213 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I); 214 Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne(); 215 #else 216 uint64_t ArgSizeBytes = 0; 217 Align ParamAlign = Align(); 218 #endif 219 auto [Arm64Ty, X64Ty, ArgTranslation] = 220 canonicalizeThunkType(FT->getParamType(I), ParamAlign, 221 /*Ret*/ false, ArgSizeBytes, Out); 222 Arm64ArgTypes.push_back(Arm64Ty); 223 X64ArgTypes.push_back(X64Ty); 224 ArgTranslations.push_back(ArgTranslation); 225 } 226 } 227 228 void AArch64Arm64ECCallLowering::getThunkRetType( 229 FunctionType *FT, AttributeList AttrList, raw_ostream &Out, 230 Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes, 231 SmallVectorImpl<Type *> &X64ArgTypes, 232 SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) { 233 Type *T = FT->getReturnType(); 234 #if 0 235 // FIXME: Need more information about argument size; see 236 // https://reviews.llvm.org/D132926 237 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes(); 238 #else 239 int64_t ArgSizeBytes = 0; 240 #endif 241 if (T->isVoidTy()) { 242 if (FT->getNumParams()) { 243 Attribute SRetAttr0 = AttrList.getParamAttr(0, Attribute::StructRet); 244 Attribute InRegAttr0 = AttrList.getParamAttr(0, Attribute::InReg); 245 Attribute SRetAttr1, InRegAttr1; 246 if (FT->getNumParams() > 1) { 247 // Also check the second parameter (for class methods, the first 248 // parameter is "this", and the second parameter is the sret pointer.) 249 // It doesn't matter which one is sret. 250 SRetAttr1 = AttrList.getParamAttr(1, Attribute::StructRet); 251 InRegAttr1 = AttrList.getParamAttr(1, Attribute::InReg); 252 } 253 if ((SRetAttr0.isValid() && InRegAttr0.isValid()) || 254 (SRetAttr1.isValid() && InRegAttr1.isValid())) { 255 // sret+inreg indicates a call that returns a C++ class value. This is 256 // actually equivalent to just passing and returning a void* pointer 257 // as the first or second argument. Translate it that way, instead of 258 // trying to model "inreg" in the thunk's calling convention; this 259 // simplfies the rest of the code, and matches MSVC mangling. 260 Out << "i8"; 261 Arm64RetTy = I64Ty; 262 X64RetTy = I64Ty; 263 return; 264 } 265 if (SRetAttr0.isValid()) { 266 // FIXME: Sanity-check the sret type; if it's an integer or pointer, 267 // we'll get screwy mangling/codegen. 268 // FIXME: For large struct types, mangle as an integer argument and 269 // integer return, so we can reuse more thunks, instead of "m" syntax. 270 // (MSVC mangles this case as an integer return with no argument, but 271 // that's a miscompile.) 272 Type *SRetType = SRetAttr0.getValueAsType(); 273 Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne(); 274 canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes, 275 Out); 276 Arm64RetTy = VoidTy; 277 X64RetTy = VoidTy; 278 Arm64ArgTypes.push_back(FT->getParamType(0)); 279 X64ArgTypes.push_back(FT->getParamType(0)); 280 ArgTranslations.push_back(ThunkArgTranslation::Direct); 281 HasSretPtr = true; 282 return; 283 } 284 } 285 286 Out << "v"; 287 Arm64RetTy = VoidTy; 288 X64RetTy = VoidTy; 289 return; 290 } 291 292 auto info = 293 canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out); 294 Arm64RetTy = info.Arm64Ty; 295 X64RetTy = info.X64Ty; 296 if (X64RetTy->isPointerTy()) { 297 // If the X64 type is canonicalized to a pointer, that means it's 298 // passed/returned indirectly. For a return value, that means it's an 299 // sret pointer. 300 X64ArgTypes.push_back(X64RetTy); 301 X64RetTy = VoidTy; 302 } 303 } 304 305 ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType( 306 Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, 307 raw_ostream &Out) { 308 309 auto direct = [](Type *T) { 310 return ThunkArgInfo{T, T, ThunkArgTranslation::Direct}; 311 }; 312 313 auto bitcast = [this](Type *Arm64Ty, uint64_t SizeInBytes) { 314 return ThunkArgInfo{Arm64Ty, 315 llvm::Type::getIntNTy(M->getContext(), SizeInBytes * 8), 316 ThunkArgTranslation::Bitcast}; 317 }; 318 319 auto pointerIndirection = [this](Type *Arm64Ty) { 320 return ThunkArgInfo{Arm64Ty, PtrTy, 321 ThunkArgTranslation::PointerIndirection}; 322 }; 323 324 if (T->isFloatTy()) { 325 Out << "f"; 326 return direct(T); 327 } 328 329 if (T->isDoubleTy()) { 330 Out << "d"; 331 return direct(T); 332 } 333 334 if (T->isFloatingPointTy()) { 335 report_fatal_error( 336 "Only 32 and 64 bit floating points are supported for ARM64EC thunks"); 337 } 338 339 auto &DL = M->getDataLayout(); 340 341 if (auto *StructTy = dyn_cast<StructType>(T)) 342 if (StructTy->getNumElements() == 1) 343 T = StructTy->getElementType(0); 344 345 if (T->isArrayTy()) { 346 Type *ElementTy = T->getArrayElementType(); 347 uint64_t ElementCnt = T->getArrayNumElements(); 348 uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8; 349 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes; 350 if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) { 351 Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes; 352 if (Alignment.value() >= 16 && !Ret) 353 Out << "a" << Alignment.value(); 354 if (TotalSizeBytes <= 8) { 355 // Arm64 returns small structs of float/double in float registers; 356 // X64 uses RAX. 357 return bitcast(T, TotalSizeBytes); 358 } else { 359 // Struct is passed directly on Arm64, but indirectly on X64. 360 return pointerIndirection(T); 361 } 362 } else if (T->isFloatingPointTy()) { 363 report_fatal_error("Only 32 and 64 bit floating points are supported for " 364 "ARM64EC thunks"); 365 } 366 } 367 368 if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) { 369 Out << "i8"; 370 return direct(I64Ty); 371 } 372 373 unsigned TypeSize = ArgSizeBytes; 374 if (TypeSize == 0) 375 TypeSize = DL.getTypeSizeInBits(T) / 8; 376 Out << "m"; 377 if (TypeSize != 4) 378 Out << TypeSize; 379 if (Alignment.value() >= 16 && !Ret) 380 Out << "a" << Alignment.value(); 381 // FIXME: Try to canonicalize Arm64Ty more thoroughly? 382 if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) { 383 // Pass directly in an integer register 384 return bitcast(T, TypeSize); 385 } else { 386 // Passed directly on Arm64, but indirectly on X64. 387 return pointerIndirection(T); 388 } 389 } 390 391 // This function builds the "exit thunk", a function which translates 392 // arguments and return values when calling x64 code from AArch64 code. 393 Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT, 394 AttributeList Attrs) { 395 SmallString<256> ExitThunkName; 396 llvm::raw_svector_ostream ExitThunkStream(ExitThunkName); 397 FunctionType *Arm64Ty, *X64Ty; 398 SmallVector<ThunkArgTranslation> ArgTranslations; 399 getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty, 400 X64Ty, ArgTranslations); 401 if (Function *F = M->getFunction(ExitThunkName)) 402 return F; 403 404 Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0, 405 ExitThunkName, M); 406 F->setCallingConv(CallingConv::ARM64EC_Thunk_Native); 407 F->setSection(".wowthk$aa"); 408 F->setComdat(M->getOrInsertComdat(ExitThunkName)); 409 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.) 410 F->addFnAttr("frame-pointer", "all"); 411 // Only copy sret from the first argument. For C++ instance methods, clang can 412 // stick an sret marking on a later argument, but it doesn't actually affect 413 // the ABI, so we can omit it. This avoids triggering a verifier assertion. 414 if (FT->getNumParams()) { 415 auto SRet = Attrs.getParamAttr(0, Attribute::StructRet); 416 auto InReg = Attrs.getParamAttr(0, Attribute::InReg); 417 if (SRet.isValid() && !InReg.isValid()) 418 F->addParamAttr(1, SRet); 419 } 420 // FIXME: Copy anything other than sret? Shouldn't be necessary for normal 421 // C ABI, but might show up in other cases. 422 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F); 423 IRBuilder<> IRB(BB); 424 Value *CalleePtr = 425 M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy); 426 Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr); 427 auto &DL = M->getDataLayout(); 428 SmallVector<Value *> Args; 429 430 // Pass the called function in x9. 431 auto X64TyOffset = 1; 432 Args.push_back(F->arg_begin()); 433 434 Type *RetTy = Arm64Ty->getReturnType(); 435 if (RetTy != X64Ty->getReturnType()) { 436 // If the return type is an array or struct, translate it. Values of size 437 // 8 or less go into RAX; bigger values go into memory, and we pass a 438 // pointer. 439 if (DL.getTypeStoreSize(RetTy) > 8) { 440 Args.push_back(IRB.CreateAlloca(RetTy)); 441 X64TyOffset++; 442 } 443 } 444 445 for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal( 446 make_range(F->arg_begin() + 1, F->arg_end()), 447 make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()), 448 ArgTranslations)) { 449 // Translate arguments from AArch64 calling convention to x86 calling 450 // convention. 451 // 452 // For simple types, we don't need to do any translation: they're 453 // represented the same way. (Implicit sign extension is not part of 454 // either convention.) 455 // 456 // The big thing we have to worry about is struct types... but 457 // fortunately AArch64 clang is pretty friendly here: the cases that need 458 // translation are always passed as a struct or array. (If we run into 459 // some cases where this doesn't work, we can teach clang to mark it up 460 // with an attribute.) 461 // 462 // The first argument is the called function, stored in x9. 463 if (ArgTranslation != ThunkArgTranslation::Direct) { 464 Value *Mem = IRB.CreateAlloca(Arg.getType()); 465 IRB.CreateStore(&Arg, Mem); 466 if (ArgTranslation == ThunkArgTranslation::Bitcast) { 467 Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType())); 468 Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy))); 469 } else { 470 assert(ArgTranslation == ThunkArgTranslation::PointerIndirection); 471 Args.push_back(Mem); 472 } 473 } else { 474 Args.push_back(&Arg); 475 } 476 assert(Args.back()->getType() == X64ArgType); 477 } 478 // FIXME: Transfer necessary attributes? sret? anything else? 479 480 Callee = IRB.CreateBitCast(Callee, PtrTy); 481 CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args); 482 Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64); 483 484 Value *RetVal = Call; 485 if (RetTy != X64Ty->getReturnType()) { 486 // If we rewrote the return type earlier, convert the return value to 487 // the proper type. 488 if (DL.getTypeStoreSize(RetTy) > 8) { 489 RetVal = IRB.CreateLoad(RetTy, Args[1]); 490 } else { 491 Value *CastAlloca = IRB.CreateAlloca(RetTy); 492 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy)); 493 RetVal = IRB.CreateLoad(RetTy, CastAlloca); 494 } 495 } 496 497 if (RetTy->isVoidTy()) 498 IRB.CreateRetVoid(); 499 else 500 IRB.CreateRet(RetVal); 501 return F; 502 } 503 504 // This function builds the "entry thunk", a function which translates 505 // arguments and return values when calling AArch64 code from x64 code. 506 Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) { 507 SmallString<256> EntryThunkName; 508 llvm::raw_svector_ostream EntryThunkStream(EntryThunkName); 509 FunctionType *Arm64Ty, *X64Ty; 510 SmallVector<ThunkArgTranslation> ArgTranslations; 511 getThunkType(F->getFunctionType(), F->getAttributes(), 512 Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty, 513 ArgTranslations); 514 if (Function *F = M->getFunction(EntryThunkName)) 515 return F; 516 517 Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0, 518 EntryThunkName, M); 519 Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64); 520 Thunk->setSection(".wowthk$aa"); 521 Thunk->setComdat(M->getOrInsertComdat(EntryThunkName)); 522 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.) 523 Thunk->addFnAttr("frame-pointer", "all"); 524 525 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk); 526 IRBuilder<> IRB(BB); 527 528 Type *RetTy = Arm64Ty->getReturnType(); 529 Type *X64RetType = X64Ty->getReturnType(); 530 531 bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy(); 532 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1; 533 unsigned PassthroughArgSize = 534 (F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset; 535 assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize)); 536 537 // Translate arguments to call. 538 SmallVector<Value *> Args; 539 for (unsigned i = 0; i != PassthroughArgSize; ++i) { 540 Value *Arg = Thunk->getArg(i + ThunkArgOffset); 541 Type *ArgTy = Arm64Ty->getParamType(i); 542 ThunkArgTranslation ArgTranslation = ArgTranslations[i]; 543 if (ArgTranslation != ThunkArgTranslation::Direct) { 544 // Translate array/struct arguments to the expected type. 545 if (ArgTranslation == ThunkArgTranslation::Bitcast) { 546 Value *CastAlloca = IRB.CreateAlloca(ArgTy); 547 IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy)); 548 Arg = IRB.CreateLoad(ArgTy, CastAlloca); 549 } else { 550 assert(ArgTranslation == ThunkArgTranslation::PointerIndirection); 551 Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy)); 552 } 553 } 554 assert(Arg->getType() == ArgTy); 555 Args.push_back(Arg); 556 } 557 558 if (F->isVarArg()) { 559 // The 5th argument to variadic entry thunks is used to model the x64 sp 560 // which is passed to the thunk in x4, this can be passed to the callee as 561 // the variadic argument start address after skipping over the 32 byte 562 // shadow store. 563 564 // The EC thunk CC will assign any argument marked as InReg to x4. 565 Thunk->addParamAttr(5, Attribute::InReg); 566 Value *Arg = Thunk->getArg(5); 567 Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20)); 568 Args.push_back(Arg); 569 570 // Pass in a zero variadic argument size (in x5). 571 Args.push_back(IRB.getInt64(0)); 572 } 573 574 // Call the function passed to the thunk. 575 Value *Callee = Thunk->getArg(0); 576 Callee = IRB.CreateBitCast(Callee, PtrTy); 577 CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args); 578 579 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet); 580 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg); 581 if (SRetAttr.isValid() && !InRegAttr.isValid()) { 582 Thunk->addParamAttr(1, SRetAttr); 583 Call->addParamAttr(0, SRetAttr); 584 } 585 586 Value *RetVal = Call; 587 if (TransformDirectToSRet) { 588 IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy)); 589 } else if (X64RetType != RetTy) { 590 Value *CastAlloca = IRB.CreateAlloca(X64RetType); 591 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy)); 592 RetVal = IRB.CreateLoad(X64RetType, CastAlloca); 593 } 594 595 // Return to the caller. Note that the isel has code to translate this 596 // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we 597 // could emit a tail call here, but that would require a dedicated calling 598 // convention, which seems more complicated overall.) 599 if (X64RetType->isVoidTy()) 600 IRB.CreateRetVoid(); 601 else 602 IRB.CreateRet(RetVal); 603 604 return Thunk; 605 } 606 607 // Builds the "guest exit thunk", a helper to call a function which may or may 608 // not be an exit thunk. (We optimistically assume non-dllimport function 609 // declarations refer to functions defined in AArch64 code; if the linker 610 // can't prove that, we use this routine instead.) 611 Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) { 612 llvm::raw_null_ostream NullThunkName; 613 FunctionType *Arm64Ty, *X64Ty; 614 SmallVector<ThunkArgTranslation> ArgTranslations; 615 getThunkType(F->getFunctionType(), F->getAttributes(), 616 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty, 617 ArgTranslations); 618 auto MangledName = getArm64ECMangledFunctionName(F->getName().str()); 619 assert(MangledName && "Can't guest exit to function that's already native"); 620 std::string ThunkName = *MangledName; 621 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) { 622 ThunkName.insert(ThunkName.find("@"), "$exit_thunk"); 623 } else { 624 ThunkName.append("$exit_thunk"); 625 } 626 Function *GuestExit = 627 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M); 628 GuestExit->setComdat(M->getOrInsertComdat(ThunkName)); 629 GuestExit->setSection(".wowthk$aa"); 630 GuestExit->setMetadata( 631 "arm64ec_unmangled_name", 632 MDNode::get(M->getContext(), 633 MDString::get(M->getContext(), F->getName()))); 634 GuestExit->setMetadata( 635 "arm64ec_ecmangled_name", 636 MDNode::get(M->getContext(), 637 MDString::get(M->getContext(), *MangledName))); 638 F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {})); 639 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit); 640 IRBuilder<> B(BB); 641 642 // Load the global symbol as a pointer to the check function. 643 Value *GuardFn; 644 if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf")) 645 GuardFn = GuardFnCFGlobal; 646 else 647 GuardFn = GuardFnGlobal; 648 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn); 649 650 // Create new call instruction. The CFGuard check should always be a call, 651 // even if the original CallBase is an Invoke or CallBr instruction. 652 Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes()); 653 CallInst *GuardCheck = B.CreateCall( 654 GuardFnType, GuardCheckLoad, 655 {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())}); 656 657 // Ensure that the first argument is passed in the correct register. 658 GuardCheck->setCallingConv(CallingConv::CFGuard_Check); 659 660 Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy); 661 SmallVector<Value *> Args; 662 for (Argument &Arg : GuestExit->args()) 663 Args.push_back(&Arg); 664 CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args); 665 Call->setTailCallKind(llvm::CallInst::TCK_MustTail); 666 667 if (Call->getType()->isVoidTy()) 668 B.CreateRetVoid(); 669 else 670 B.CreateRet(Call); 671 672 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet); 673 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg); 674 if (SRetAttr.isValid() && !InRegAttr.isValid()) { 675 GuestExit->addParamAttr(0, SRetAttr); 676 Call->addParamAttr(0, SRetAttr); 677 } 678 679 return GuestExit; 680 } 681 682 Function * 683 AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias, 684 GlobalAlias *MangledAlias) { 685 llvm::raw_null_ostream NullThunkName; 686 FunctionType *Arm64Ty, *X64Ty; 687 Function *F = cast<Function>(MangledAlias->getAliasee()); 688 SmallVector<ThunkArgTranslation> ArgTranslations; 689 getThunkType(F->getFunctionType(), F->getAttributes(), 690 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty, 691 ArgTranslations); 692 std::string ThunkName(MangledAlias->getName()); 693 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) { 694 ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk"); 695 } else { 696 ThunkName.append("$hybpatch_thunk"); 697 } 698 699 Function *GuestExit = 700 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M); 701 GuestExit->setComdat(M->getOrInsertComdat(ThunkName)); 702 GuestExit->setSection(".wowthk$aa"); 703 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit); 704 IRBuilder<> B(BB); 705 706 // Load the global symbol as a pointer to the check function. 707 LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal); 708 709 // Create new dispatch call instruction. 710 Function *ExitThunk = 711 buildExitThunk(F->getFunctionType(), F->getAttributes()); 712 CallInst *Dispatch = 713 B.CreateCall(DispatchFnType, DispatchLoad, 714 {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()}); 715 716 // Ensure that the first arguments are passed in the correct registers. 717 Dispatch->setCallingConv(CallingConv::CFGuard_Check); 718 719 Value *DispatchRetVal = B.CreateBitCast(Dispatch, PtrTy); 720 SmallVector<Value *> Args; 721 for (Argument &Arg : GuestExit->args()) 722 Args.push_back(&Arg); 723 CallInst *Call = B.CreateCall(Arm64Ty, DispatchRetVal, Args); 724 Call->setTailCallKind(llvm::CallInst::TCK_MustTail); 725 726 if (Call->getType()->isVoidTy()) 727 B.CreateRetVoid(); 728 else 729 B.CreateRet(Call); 730 731 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet); 732 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg); 733 if (SRetAttr.isValid() && !InRegAttr.isValid()) { 734 GuestExit->addParamAttr(0, SRetAttr); 735 Call->addParamAttr(0, SRetAttr); 736 } 737 738 MangledAlias->setAliasee(GuestExit); 739 return GuestExit; 740 } 741 742 // Lower an indirect call with inline code. 743 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) { 744 assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() && 745 "Only applicable for Windows targets"); 746 747 IRBuilder<> B(CB); 748 Value *CalledOperand = CB->getCalledOperand(); 749 750 // If the indirect call is called within catchpad or cleanuppad, 751 // we need to copy "funclet" bundle of the call. 752 SmallVector<llvm::OperandBundleDef, 1> Bundles; 753 if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet)) 754 Bundles.push_back(OperandBundleDef(*Bundle)); 755 756 // Load the global symbol as a pointer to the check function. 757 Value *GuardFn; 758 if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf")) 759 GuardFn = GuardFnCFGlobal; 760 else 761 GuardFn = GuardFnGlobal; 762 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn); 763 764 // Create new call instruction. The CFGuard check should always be a call, 765 // even if the original CallBase is an Invoke or CallBr instruction. 766 Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes()); 767 CallInst *GuardCheck = 768 B.CreateCall(GuardFnType, GuardCheckLoad, 769 {B.CreateBitCast(CalledOperand, B.getPtrTy()), 770 B.CreateBitCast(Thunk, B.getPtrTy())}, 771 Bundles); 772 773 // Ensure that the first argument is passed in the correct register. 774 GuardCheck->setCallingConv(CallingConv::CFGuard_Check); 775 776 Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType()); 777 CB->setCalledOperand(GuardRetVal); 778 } 779 780 bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { 781 if (!GenerateThunks) 782 return false; 783 784 M = &Mod; 785 786 // Check if this module has the cfguard flag and read its value. 787 if (auto *MD = 788 mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard"))) 789 cfguard_module_flag = MD->getZExtValue(); 790 791 PtrTy = PointerType::getUnqual(M->getContext()); 792 I64Ty = Type::getInt64Ty(M->getContext()); 793 VoidTy = Type::getVoidTy(M->getContext()); 794 795 GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false); 796 GuardFnPtrType = PointerType::get(GuardFnType, 0); 797 DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false); 798 DispatchFnPtrType = PointerType::get(DispatchFnType, 0); 799 GuardFnCFGlobal = 800 M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType); 801 GuardFnGlobal = 802 M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType); 803 DispatchFnGlobal = 804 M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType); 805 806 DenseMap<GlobalAlias *, GlobalAlias *> FnsMap; 807 SetVector<GlobalAlias *> PatchableFns; 808 809 for (Function &F : Mod) { 810 if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() || 811 F.hasLocalLinkage() || F.getName().ends_with("$hp_target")) 812 continue; 813 814 // Rename hybrid patchable functions and change callers to use a global 815 // alias instead. 816 if (std::optional<std::string> MangledName = 817 getArm64ECMangledFunctionName(F.getName().str())) { 818 std::string OrigName(F.getName()); 819 F.setName(MangledName.value() + "$hp_target"); 820 821 // The unmangled symbol is a weak alias to an undefined symbol with the 822 // "EXP+" prefix. This undefined symbol is resolved by the linker by 823 // creating an x86 thunk that jumps back to the actual EC target. Since we 824 // can't represent that in IR, we create an alias to the target instead. 825 // The "EXP+" symbol is set as metadata, which is then used by 826 // emitGlobalAlias to emit the right alias. 827 auto *A = 828 GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F); 829 F.replaceAllUsesWith(A); 830 F.setMetadata("arm64ec_exp_name", 831 MDNode::get(M->getContext(), 832 MDString::get(M->getContext(), 833 "EXP+" + MangledName.value()))); 834 A->setAliasee(&F); 835 836 if (F.hasDLLExportStorageClass()) { 837 A->setDLLStorageClass(GlobalValue::DLLExportStorageClass); 838 F.setDLLStorageClass(GlobalValue::DefaultStorageClass); 839 } 840 841 FnsMap[A] = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, 842 MangledName.value(), &F); 843 PatchableFns.insert(A); 844 } 845 } 846 847 SetVector<GlobalValue *> DirectCalledFns; 848 for (Function &F : Mod) 849 if (!F.isDeclaration() && 850 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native && 851 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) 852 processFunction(F, DirectCalledFns, FnsMap); 853 854 struct ThunkInfo { 855 Constant *Src; 856 Constant *Dst; 857 Arm64ECThunkType Kind; 858 }; 859 SmallVector<ThunkInfo> ThunkMapping; 860 for (Function &F : Mod) { 861 if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) && 862 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native && 863 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) { 864 if (!F.hasComdat()) 865 F.setComdat(Mod.getOrInsertComdat(F.getName())); 866 ThunkMapping.push_back( 867 {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry}); 868 } 869 } 870 for (GlobalValue *O : DirectCalledFns) { 871 auto GA = dyn_cast<GlobalAlias>(O); 872 auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O); 873 ThunkMapping.push_back( 874 {O, buildExitThunk(F->getFunctionType(), F->getAttributes()), 875 Arm64ECThunkType::Exit}); 876 if (!GA && !F->hasDLLImportStorageClass()) 877 ThunkMapping.push_back( 878 {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit}); 879 } 880 for (GlobalAlias *A : PatchableFns) { 881 Function *Thunk = buildPatchableThunk(A, FnsMap[A]); 882 ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit}); 883 } 884 885 if (!ThunkMapping.empty()) { 886 SmallVector<Constant *> ThunkMappingArrayElems; 887 for (ThunkInfo &Thunk : ThunkMapping) { 888 ThunkMappingArrayElems.push_back(ConstantStruct::getAnon( 889 {ConstantExpr::getBitCast(Thunk.Src, PtrTy), 890 ConstantExpr::getBitCast(Thunk.Dst, PtrTy), 891 ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))})); 892 } 893 Constant *ThunkMappingArray = ConstantArray::get( 894 llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(), 895 ThunkMappingArrayElems.size()), 896 ThunkMappingArrayElems); 897 new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false, 898 GlobalValue::ExternalLinkage, ThunkMappingArray, 899 "llvm.arm64ec.symbolmap"); 900 } 901 902 return true; 903 } 904 905 bool AArch64Arm64ECCallLowering::processFunction( 906 Function &F, SetVector<GlobalValue *> &DirectCalledFns, 907 DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) { 908 SmallVector<CallBase *, 8> IndirectCalls; 909 910 // For ARM64EC targets, a function definition's name is mangled differently 911 // from the normal symbol. We currently have no representation of this sort 912 // of symbol in IR, so we change the name to the mangled name, then store 913 // the unmangled name as metadata. Later passes that need the unmangled 914 // name (emitting the definition) can grab it from the metadata. 915 // 916 // FIXME: Handle functions with weak linkage? 917 if (!F.hasLocalLinkage() || F.hasAddressTaken()) { 918 if (std::optional<std::string> MangledName = 919 getArm64ECMangledFunctionName(F.getName().str())) { 920 F.setMetadata("arm64ec_unmangled_name", 921 MDNode::get(M->getContext(), 922 MDString::get(M->getContext(), F.getName()))); 923 if (F.hasComdat() && F.getComdat()->getName() == F.getName()) { 924 Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value()); 925 SmallVector<GlobalObject *> ComdatUsers = 926 to_vector(F.getComdat()->getUsers()); 927 for (GlobalObject *User : ComdatUsers) 928 User->setComdat(MangledComdat); 929 } 930 F.setName(MangledName.value()); 931 } 932 } 933 934 // Iterate over the instructions to find all indirect call/invoke/callbr 935 // instructions. Make a separate list of pointers to indirect 936 // call/invoke/callbr instructions because the original instructions will be 937 // deleted as the checks are added. 938 for (BasicBlock &BB : F) { 939 for (Instruction &I : BB) { 940 auto *CB = dyn_cast<CallBase>(&I); 941 if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 || 942 CB->isInlineAsm()) 943 continue; 944 945 // We need to instrument any call that isn't directly calling an 946 // ARM64 function. 947 // 948 // FIXME: getCalledFunction() fails if there's a bitcast (e.g. 949 // unprototyped functions in C) 950 if (Function *F = CB->getCalledFunction()) { 951 if (!LowerDirectToIndirect || F->hasLocalLinkage() || 952 F->isIntrinsic() || !F->isDeclaration()) 953 continue; 954 955 DirectCalledFns.insert(F); 956 continue; 957 } 958 959 // Use mangled global alias for direct calls to patchable functions. 960 if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) { 961 auto I = FnsMap.find(A); 962 if (I != FnsMap.end()) { 963 CB->setCalledOperand(I->second); 964 DirectCalledFns.insert(I->first); 965 continue; 966 } 967 } 968 969 IndirectCalls.push_back(CB); 970 ++Arm64ECCallsLowered; 971 } 972 } 973 974 if (IndirectCalls.empty()) 975 return false; 976 977 for (CallBase *CB : IndirectCalls) 978 lowerCall(CB); 979 980 return true; 981 } 982 983 char AArch64Arm64ECCallLowering::ID = 0; 984 INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering", 985 "AArch64Arm64ECCallLowering", false, false) 986 987 ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() { 988 return new AArch64Arm64ECCallLowering; 989 } 990